How to use b2c.view_labels() for cell-type or specific gene visualization? #24

Pancreas-Pratik opened this issue Nov 17, 2024 · 2 comments


Pancreas-Pratik commented Nov 17, 2024

Thank you very much for bin2cell.

@ktpolanski, is b2c.view_labels() in 0.3.2 the unformalized custom function that @nadavyayon referred to and was using at 18 minutes and 35 seconds (18:35) in the September 2024 10X webinar (https://pages.10xgenomics.com/WBR-2024-09-EVENT-VIS-VISIUM-HD-ANALYSIS-PARENT_LP.html)?

  1. I want to do exactly this type of visualization: cell-type confidence scores in viridis, with the same high-resolution, fast/lightweight viewing of different parts of the hi-res image that was shown in the 10X webinar.
  2. I also want a separate visualization style showing canonical gene marker expression (instead of cell-type confidence scores), also in viridis, with cell segment boundaries colored via border_color.

It looks like this can be done fairly easily, based on my reading of #21 (comment), #21 (comment) (b2c.view_labels()), and https://bin2cell.readthedocs.io/en/latest/bin2cell.view_labels.html#bin2cell.view_labels.

Also, how and where could this full scale image be visualized the way it was in the 10X webinar? I am guessing napari can do this. Would there be a way to save the full scale image output to disk for visualization?

If you still have some example code, could you share this, please?
It would help me do (1), and I could figure out how to do (2) from (1). I am more familiar with R and Seurat than Python.

@ktpolanski (Contributor)

The function in 0.3.2 inspired Nadav to write whatever he used to make the visualisation for the seminar, which I recently rewrote to run quickly. This has not been distributed anywhere at this time; I shared a beta version with Nadav for him to play around with. Once there's a reasonable form to share, I'll get back to you with it.


ktpolanski commented Dec 3, 2024

Alright, I think I have something shareable. Stick this into a Jupyter notebook cell, or just execute wherever it is you execute code:

#new deps to add!
import skimage.segmentation
import seaborn as sns
import matplotlib

import matplotlib.pyplot as plt
import scipy.sparse
import scipy.stats
import bin2cell as b2c
import numpy as np

from PIL import Image
#setting needed so PIL can load the large TIFFs
Image.MAX_IMAGE_PIXELS = None

def overlay_onto_img(img, labels_sparse, cdata, key, common_objects, 
                      fill_label_weight=1,
                      make_legends=True
                     ):
    #are we in obs or in var?
    if key in cdata.obs.columns:
        #continuous or categorical?
        if ("float" in cdata.obs[key].dtype.name) or ("int" in cdata.obs[key].dtype.name):
            #we've got a continuous
            #subset on common_objects, so the order matches them
            #also need to turn to string to match cdata.obs_names
            vals = cdata.obs.loc[[str(i) for i in common_objects], key].values
            #we'll continue processing shortly outside the ifs
        elif "category" in cdata.obs[key].dtype.name:
            #we've got a categorical
            #use tab20 to try to get some different ones
            #convert to uint8 for internal consistency
            fill_palette = (np.array(sns.color_palette("tab20"))*255).astype(np.uint8)
            #pull out the category indices for each of the common objects
            #and store them in a numpy array that can be indexed on the object in int form
            cats = np.zeros(np.max(common_objects)+1, dtype=np.int32)
            #need to turn common_objects back to string so cdata can be subset on them
            cats[common_objects] = cdata.obs.loc[[str(i) for i in common_objects], key].cat.codes
            #store the original present category codes for legend purposes
            cats_unique_original = np.unique(cats[common_objects])
            #reset the cats[common_objects] to start from 0 and go up by 1
            #which may avoid some palette overlaps in some corner case scenarios
            #calling this on [5,7,7,9] yields [1,2,2,3] which is what we want
            #except then shift it back by 1 so it starts from 0
            cats[common_objects] = scipy.stats.rankdata(cats[common_objects], method="dense") - 1
            #now we have a master list of pixels with objects to show
            #.row is [:,0], .col is [:,1]
            #extract the existing values from the image
            #and simultaneously get a fill colour by doing % on number of fill colours
            #pull out the category index by subsetting on the actual object ID
            #weight the two together to get the new pixel value
            img[labels_sparse.row, labels_sparse.col, :] = \
                (1-fill_label_weight) * img[labels_sparse.row, labels_sparse.col, :] + \
                fill_label_weight * fill_palette[cats[labels_sparse.data] % fill_palette.shape[0], :]
            #set up legend
            #figsize is largely irrelevant because of bbox_inches='tight' when saving
            fig, ax = plt.subplots(figsize=(5, 2))
            ax.axis("off")
            #alright, there's some stuff going on here. let's explain
            #the categories don't actually get subset when you pull out a chunk of the vector
            #which kinda makes sense but is annoying
            #meanwhile we want to minimise our ID footprint to try to use the colour map nicely
            #as such, earlier we made cats_unique_original, having all the codes from the objects
            #and then we reset the cats to start at 0 and go up by 1
            #now we can get a correspondence of those to cats_unique_original
            #by doing a subset and np.unique(), as this is monotonically preserved
            #still need to % the category code like before
            legend_patches = [
                matplotlib.patches.Patch(color=fill_palette[i % fill_palette.shape[0],:]/255.0, 
                                         label=cdata.obs[key].cat.categories[j]
                                        )
                for i, j in zip(np.unique(cats[common_objects]), cats_unique_original)
            ]
            ax.legend(handles=legend_patches, loc="center", title=key, frameon=False)
            #close the thing so it doesn't randomly show. still there though
            plt.close(fig)
            #okay we're happy. return to recycle continuous processing code
            return img, fig
        else:
            #we've got to raise an error
            raise ValueError("``cdata.obs['"+key+"']`` must be a float, int, or categorical")
    elif key in cdata.var_names:
        #gene, continuous
        #fast enough to just subset the cdata, as always turn to string
        #then regardless if it's sparse or dense data need to .toarray().flatten()
        #if it's sparse then this turns it dense
        #if it's dense then it gets out of ArrayView back into a normal array
        #and then flattened regardless
        vals = cdata[[str(i) for i in common_objects]][:,key].X.toarray().flatten()
    else:
        #we've got a whiff
        raise ValueError("'"+key+"' not found in ``cdata.obs`` or ``cdata.var``")
    #we're out here. so we're processing a continuous, be it obs or var
    #set up legend while vals are on their actual scale
    #figsize is largely irrelevant because of bbox_inches='tight' when saving
    fig, ax = plt.subplots(figsize=(5, 2))
    ax.axis("off")
    colormap = matplotlib.colormaps.get_cmap("viridis")
    norm = matplotlib.colors.Normalize(vmin=np.min(vals), vmax=np.max(vals))
    sm = matplotlib.cm.ScalarMappable(cmap=colormap, norm=norm)
    fig.colorbar(sm, ax=ax, orientation="horizontal", label=key)
    #close the thing so it doesn't randomly show. still there though
    plt.close(fig)
    #for continuous operations, we need a 0-1 scaled vector of values
    vals = (vals-np.min(vals))/(np.max(vals)-np.min(vals))
    #construct a fill palette by busting out a colormap
    #and then getting its values at vals, and sending them to a prepared numpy array
    #that can then be subset on the object ID to get its matching RGB of the continuous value
    fill_palette = np.zeros((np.max(common_objects)+1, 3))
    #send to common_objects, matching vals order. also convert to uint8 for consistency
    fill_palette[common_objects, :] = (matplotlib.colormaps.get_cmap('viridis')(vals)[:,:3]*255).astype(np.uint8)
    #now we have a master list of pixels with objects to show
    #.row is [:,0], .col is [:,1]
    #extract the existing values from the image
    #and simultaneously get a fill colour by subsetting the palette on the label ID
    #no fancy % needed here as each object has its own fill value prepared
    #weight the two together to get the new pixel value
    img[labels_sparse.row, labels_sparse.col, :] = \
        (1-fill_label_weight) * img[labels_sparse.row, labels_sparse.col, :] + \
        fill_label_weight * fill_palette[labels_sparse.data, :]
    return img, fig
    
    
def cell_label_render(image_path, labels_npz_path, cdata, 
                      fill_key=None, 
                      border_key=None, 
                      crop=None, 
                      stardist_normalize=False, 
                      fill_label_weight=1, 
                      thicken_border=True
                     ):
    #load the sparse labels
    labels_sparse = scipy.sparse.load_npz(labels_npz_path)
    #determine memory efficient dtype to load the image as
    #if we'll be normalising, we want np.float16 for optimal RAM footprint
    #otherwise use np.uint8
    if stardist_normalize:
        dtype = np.float16
    else:
        dtype = np.uint8
    if crop is None:
        #this will load greyscale as 3 channel, which is what we want here
        img = b2c.load_image(image_path, dtype=dtype)
    else:
        #PIL is better at handling crops memory efficiently than cv2
        img = Image.open(image_path)
        #ensure that it's in RGB (otherwise there's a single channel for greyscale)
        img = np.array(img.crop(crop).convert('RGB'), dtype=dtype)
        #subset labels to area of interest
        #crop is (left, upper, right, lower)
        #https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.crop
        #upper:lower, left:right
        labels_sparse = labels_sparse[crop[1]:crop[3], crop[0]:crop[2]]
    #optionally normalise image
    if stardist_normalize:
        img = b2c.normalize(img)
        #actually cap the values - currently there are sub 0 and above 1 entries
        img[img<0] = 0
        img[img>1] = 1
        #turn back to uint8 for internal consistency
        img = (255*img).astype(np.uint8)
    #turn labels to COO for ease of position retrieval
    labels_sparse = labels_sparse.tocoo()
    #identify overlap between label objects and cdata observations
    #which should be the same nomenclature for the morphology segmentation
    #just as strings, while the labels are ints
    common_objects = np.sort(list(set(np.unique(labels_sparse.data)).intersection(set([int(i) for i in cdata.obs_names]))))
    #kick out filtered out objects from segmentation results
    labels_sparse.data[~np.isin(labels_sparse.data, common_objects)] = 0
    labels_sparse.eliminate_zeros()
    #catch legends to dictionary for returning later
    legends = {}
    #do a fill if requested
    if fill_key is not None:
        #legend comes in the form of a figure
        img, fig = overlay_onto_img(img=img, 
                                    labels_sparse=labels_sparse, 
                                    cdata=cdata, 
                                    key=fill_key, 
                                    common_objects=common_objects, 
                                    fill_label_weight=fill_label_weight
                                   )
        legends[fill_key] = fig
    #do a border if requested
    if border_key is not None:
        #actually get the border
        #unfortunately the boundary finder wants a dense matrix, so turn our labels to it for a sec
        #go for inner borders as that's all we care about, and that's fewer pixels to worry about
        #keep the border dense because of implementation reasons. it spikes RAM anyway when it's made
        border = skimage.segmentation.find_boundaries(np.array(labels_sparse.todense()), mode="inner")
        #whether we thicken or not, we need nonzero coordinates
        coords = np.nonzero(border)
        if thicken_border:
            #we're thickening. skimage.segmentation.expand_labels() explodes when asked to do this
            #the following gets the job done quicker and with lower RAM footprint
            #take existing nonzero coordinates and move them to the left, right, up and down by 1
            border_rows = np.hstack([
                np.clip(coords[0]-1, a_min=0, a_max=None),
                np.clip(coords[0]+1, a_min=None, a_max=border.shape[0]-1),
                coords[0],
                coords[0]
            ])
            border_cols = np.hstack([
                coords[1],
                coords[1],
                np.clip(coords[1]-1, a_min=0, a_max=None),
                np.clip(coords[1]+1, a_min=None, a_max=border.shape[1]-1)
            ])
            #set the positions to True. this is BLAZING FAST compared to sparse attempts
            #or, surprisingly, trying to np.unique() to get non-duplicate coordinates
            #and the entire reason we kept the borders dense for this process
            border[border_rows, border_cols] = True
            #update our nonzero coordinates
            coords = np.nonzero(border)
        #to assign borders back to objects, subset the object labels to just the border pixels
        #technically need to construct a new COO matrix for it, pulling out values at the border coordinates
        #use (data, (row, col)) constructor
        #also need to turn labels to CSR to be able to pull out their values
        #which results in a 2d numpy matrix, so turn to 1D array or constructor errors
        labels_sparse = scipy.sparse.coo_matrix((np.array(labels_sparse.tocsr()[coords[0], coords[1]]).flatten(), coords), shape=labels_sparse.shape)
        #the thickener can introduce zeroes; with border mode="inner" and no thickening there are none
        labels_sparse.eliminate_zeros()
        #can now run the overlayer, set weights to 1 to have fully opaque borders
        #legend comes in the form of a figure
        img, fig = overlay_onto_img(img=img, 
                                    labels_sparse=labels_sparse, 
                                    cdata=cdata, 
                                    key=border_key, 
                                    common_objects=common_objects, 
                                    fill_label_weight=1
                                   )
        legends[border_key] = fig
    return img, legends

The function of relevance to you is cell_label_render(). You need to feed it the image_path and labels_npz_path from your morphology segmentation, and the cell-level cdata that you make at the end of the basic bin2cell notebook. The .obs_names must be unchanged. You can then specify fill_key and/or border_key to be any float/int/categorical in .obs, or a gene name present in .var_names. The output is the resulting image, plus a legends dictionary holding matplotlib figures with the legends for the keys you passed (the key is how you pull each legend out of the dictionary).

A usage example on the cell object made in the mouse brain tutorial, run after the end of the notebook. The crop is not necessary in normal full scale image use, but is specified in this illustrative example to get colours to line up with what's already familiar:

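#ddata, mpp and cdata below are as defined in the bin2cell mouse brain tutorial notebook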
crop = b2c.get_crop(ddata, basis="spatial", spatial_key="spatial_cropped_150_buffer", mpp=mpp)
#this is a string at this time
cdata.obs["labels_joint_source"] = cdata.obs["labels_joint_source"].astype("category")

img, legends = cell_label_render(image_path="stardist/he.tiff",
                                 labels_npz_path="stardist/he.npz",
                                 cdata=cdata,
                                 crop=crop,
                                 fill_key="bin_count",
                                 border_key="labels_joint_source"
                                )

#looking at output, the legends just need to be pulled out from the dictionary and they display
plt.imshow(img)
legends["bin_count"]
legends["labels_joint_source"]
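
For point (2) in the original question, the same call works with a gene as the fill, since fill_key accepts anything present in .var_names. A minimal sketch, where "Ttr" is a hypothetical stand-in for whatever marker gene you want, provided it's present in cdata.var_names:

#hypothetical gene key - swap "Ttr" for any marker in your cdata.var_names
img_gene, legends_gene = cell_label_render(image_path="stardist/he.tiff",
                                           labels_npz_path="stardist/he.npz",
                                           cdata=cdata,
                                           crop=crop,
                                           fill_key="Ttr",
                                           border_key="labels_joint_source"
                                          )
plt.imshow(img_gene)
legends_gene["Ttr"]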

#saving output; for the image, you can do whatever really - it's a numpy array representation
plt.imsave("img.png", img)
legends["bin_count"].savefig("bin_count.png", bbox_inches="tight")
legends["labels_joint_source"].savefig("labels_joint_source.png", bbox_inches="tight")
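
For the napari/full scale viewing part of the question, a minimal sketch (assuming tifffile and napari are installed; neither is a bin2cell dependency) is to write the array out as a TIFF and open it in the napari viewer, which copes well with panning and zooming around large images:

import tifffile
import napari

#write the full scale render to disk as a TIFF
tifffile.imwrite("img.tiff", img)
#open interactively in napari for fast panning/zooming
viewer = napari.view_image(img)
napari.run()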
