Error(s) in loading state_dict for trVAE #55

Open
LukasHats opened this issue Nov 3, 2024 · 4 comments
Labels
bug Something isn't working

Comments

LukasHats commented Nov 3, 2024

Dear @marcovarrone,

As you suggested in the other issue, I wanted to use trVAE for batch correction. However, I get an error when trying to load the saved trVAE model. Here is my code; the data is from imaging mass cytometry (IMC).

```python
import squidpy as sq
import cellcharter as cc
import scanpy as sc
import anndata as ad
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import scarches as sca
from lightning.pytorch import seed_everything
seed_everything(42)

# This data has the arcsinh-transformed values already in adata.X
adata = ad.read_h5ad("/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/results/standard/adatas/cells_annotated_pp_osteocytes_cleaned.h5ad")
adata.X = adata.X.astype(np.float32)
condition_key = 'patient_ID'
cell_type_key = 'Phenotype3'
conditions = adata.obs[condition_key].unique().tolist()

trvae_epochs = 500
surgery_epochs = 500

early_stopping_kwargs = {
    "early_stopping_metric": "val_unweighted_loss",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}
trvae = sca.models.TRVAE(
    adata=adata,
    condition_key=condition_key,
    conditions=conditions,
    recon_loss='mse',
    use_mmd=False,
)

trvae.train(
    n_epochs=trvae_epochs,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
    enable_progress_bar=True,
)
```

```
Preparing (1018580, 33)
Instantiating dataset
 |██████--------------| 34.4%  - val_loss: 9.5704051598 - val_recon_loss: 6.6211908229 - val_kl_loss: 3.4493733938
ADJUSTED LR
 |████████------------| 41.4%  - val_loss: 9.3446679523 - val_recon_loss: 6.2700993637 - val_kl_loss: 3.0745685972
ADJUSTED LR
 |█████████-----------| 49.4%  - val_loss: 9.1646852505 - val_recon_loss: 5.9590985781 - val_kl_loss: 3.2055866670
ADJUSTED LR
 |██████████----------| 53.6%  - val_loss: 9.1119472034 - val_recon_loss: 5.9010069772 - val_kl_loss: 3.2109402164
ADJUSTED LR
 |███████████---------| 56.8%  - val_loss: 9.1155722609 - val_recon_loss: 5.9224601750 - val_kl_loss: 3.1931120732
ADJUSTED LR
 |███████████---------| 58.2%  - val_loss: 9.1159391152 - val_recon_loss: 5.9208545379 - val_kl_loss: 3.1950845703
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state of network...
Best State was in Epoch 269
```
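The "ADJUSTED LR" and "Stopping early ... in 20 epochs" lines come from the `early_stopping_kwargs` above. To illustrate how `threshold`, `patience`, `lr_patience` and `lr_factor` usually interact, here is a minimal, hypothetical plateau rule in plain Python. It assumes these parameters carry their conventional meaning; it is a simplified sketch, not scArches's actual early-stopping code, and the starting learning rate is an assumed value.

```python
from dataclasses import dataclass


@dataclass
class PlateauStopper:
    """Hypothetical early-stopping rule with learning-rate reduction on plateau.

    Illustrates the conventional meaning of the early_stopping_kwargs above;
    simplified sketch, not scArches's implementation.
    """

    threshold: float = 0.0  # a loss decrease must exceed this to count as improvement
    patience: int = 20      # stop after this many epochs without improvement
    lr_patience: int = 13   # reduce the LR after this many epochs without improvement
    lr_factor: float = 0.1  # multiply the LR by this factor on each reduction
    lr: float = 1e-3        # assumed starting learning rate
    best_loss: float = float("inf")
    best_epoch: int = -1
    epoch: int = -1
    wait: int = 0
    wait_lr: int = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return False once training should stop."""
        self.epoch += 1
        if self.best_loss - val_loss > self.threshold:
            # Improvement: remember this epoch (the "best state") and reset both counters.
            self.best_loss, self.best_epoch = val_loss, self.epoch
            self.wait = self.wait_lr = 0
            return True
        self.wait += 1
        self.wait_lr += 1
        if self.wait_lr >= self.lr_patience:
            self.lr *= self.lr_factor  # corresponds to an "ADJUSTED LR" line in the log
            self.wait_lr = 0
        return self.wait < self.patience


# Toy validation losses: three improving epochs, then a plateau.
stopper = PlateauStopper()
for loss in [10.0, 9.5, 9.2] + [9.3] * 30:
    if not stopper.step(loss):
        break
print(f"stopped at epoch {stopper.epoch}, best epoch {stopper.best_epoch}, lr {stopper.lr:.0e}")
# prints: stopped at epoch 22, best epoch 2, lr 1e-04
```

In this sketch every improvement resets both counters, which is why several LR reductions can happen before the final stop, as in the log above.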

```python
trvae.save('/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/github/myeloma_standal/src/downstream/advanced_neighborhood/cellcharter/', overwrite=True)
```

Because of a separate issue, I had to train with use_mmd=False.
Now, when I try to load the model, I get this error:

```python
sc.pp.scale(adata)

model = cc.tl.TRVAE.load(
    '/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/github/myeloma_standal/src/downstream/advanced_neighborhood/cellcharter',
    adata,
    map_location='cpu'
)
```

```
AnnData object with n_obs × n_vars = 1018580 × 33
    obs: 'Object', 'area', 'Y_centroid', 'X_centroid', 'axis_major_length', 'axis_minor_length', 'eccentricity', 'distance_to_bone', 'Phenotype', 'image_ID', 'disease', 'patient_ID', 'ROI', 'disease2', 'distance_to_bone_corrected', 'Phenotype2', 'Phenotype3'
    var: 'name', 'channel', 'deepcell', 'mean', 'std'
    layers: 'arcsinh', 'zscore'
```

```
INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 33 256 77
	Hidden Layer 1 in/out: 256 64
	Mean/Var Layer in/out: 64 10
Decoder Architecture:
	First Layer in, out and cond:  10 64 77
	Hidden Layer 1 in/out: 64 256
	Output Layer in/out:  256 33
```


```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 1
----> 1 model = cc.tl.TRVAE.load(
      2     '/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/github/myeloma_standal/src/downstream/advanced_neighborhood/cellcharter',
      3     adata,
      4     map_location='cpu'
      5 )

File ~/miniforge3/envs/cellcharter_scarches/lib/python3.9/site-packages/cellcharter/tl/_trvae.py:181, in TRVAE.load(cls, dir_path, adata, map_location)
    178 init_params = cls._get_init_params_from_dict(attr_dict)
    180 model = cls(adata, **init_params)
--> 181 model.model.load_state_dict(model_state_dict)
    182 model.model.eval()
    184 model.is_trained_ = attr_dict["is_trained_"]

File ~/miniforge3/envs/cellcharter_scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:2584, in Module.load_state_dict(self, state_dict, strict, assign)
   2576         error_msgs.insert(
   2577             0,
   2578             "Missing key(s) in state_dict: {}. ".format(
   2579                 ", ".join(f'"{k}"' for k in missing_keys)
   2580             ),
   2581         )
   2583 if len(error_msgs) > 0:
-> 2584     raise RuntimeError(
   2585         "Error(s) in loading state_dict for {}:\n\t{}".format(
   2586             self.__class__.__name__, "\n\t".join(error_msgs)
   2587         )
   2588     )
   2589 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for trVAE:
	Missing key(s) in state_dict: "decoder.recon_decoder.weight", "decoder.recon_decoder.bias".
	Unexpected key(s) in state_dict: "decoder.recon_decoder.0.weight", "decoder.recon_decoder.0.bias".
```
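The missing and unexpected keys differ only by a numeric `.0` path segment. In PyTorch, a submodule stored inside `torch.nn.Sequential` is named by its position, so a linear layer at index 0 contributes `0.weight` and `0.bias`, while a parameter-free `ReLU` at index 1 contributes no keys at all. The error is therefore consistent with a checkpoint whose output layer was a `Sequential(Linear, ReLU)` being loaded into a model whose output layer is a bare `Linear`. The standard-library helper below is a hypothetical sketch that pairs each missing key with an unexpected key differing only by such an index; the key names are copied from the error message.

```python
def diff_keys(expected, saved):
    """Return (missing, unexpected) keys, as load_state_dict(strict=True) reports them."""
    expected, saved = set(expected), set(saved)
    return sorted(expected - saved), sorted(saved - expected)


def pair_by_index_segment(missing, unexpected):
    """Map each missing key to an unexpected key that differs only by numeric segments.

    Numeric segments such as the "0" in "recon_decoder.0.weight" are what
    torch.nn.Sequential adds when it names its children by position.
    """
    pairs = {}
    for key in unexpected:
        collapsed = ".".join(part for part in key.split(".") if not part.isdigit())
        if collapsed in missing:
            pairs[collapsed] = key
    return pairs


# Keys copied from the error message above.
expected = ["decoder.recon_decoder.weight", "decoder.recon_decoder.bias"]
saved = ["decoder.recon_decoder.0.weight", "decoder.recon_decoder.0.bias"]

missing, unexpected = diff_keys(expected, saved)
for model_key, checkpoint_key in sorted(pair_by_index_segment(missing, unexpected).items()):
    print(f"{model_key}  <-  {checkpoint_key}")
# decoder.recon_decoder.bias  <-  decoder.recon_decoder.0.bias
# decoder.recon_decoder.weight  <-  decoder.recon_decoder.0.weight
```

With real objects, `expected` would come from `model.model.state_dict().keys()` and `saved` from the saved state dict. Renaming the keys would likely let them load because the layer shapes match, but those weights were trained with the extra ReLU, so training with CellCharter's own implementation is the cleaner fix.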
  

Any ideas on this? Why is there an additional .0. in the key names?

I could work around this by directly using the autok object to transfer the latent adata object. However, I trained the trVAE on normalized (non-scaled) data and wanted to apply it to the scaled data, as you suggested. I trained it only on normalized data because the trVAE documentation says it should be trained on normalized data when the MSE loss is used. I want to make sure I am not doing anything wrong here. Or should I also train it on scaled data?

Thanks a lot (again) in advance! :)

Version information

# Name                    Version                   Build  Channel
absl-py                   2.1.0                    pypi_0    pypi
affine                    2.4.0                    pypi_0    pypi
aiobotocore               2.5.4                    pypi_0    pypi
aiohappyeyeballs          2.4.3                    pypi_0    pypi
aiohttp                   3.10.10                  pypi_0    pypi
aioitertools              0.12.0                   pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
anndata                   0.10.8                   pypi_0    pypi
appnope                   0.1.4              pyhd8ed1ab_0    conda-forge
array-api-compat          1.9.1                    pypi_0    pypi
asciitree                 0.3.3                    pypi_0    pypi
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
async-timeout             4.0.3                    pypi_0    pypi
attrs                     24.2.0                   pypi_0    pypi
beautifulsoup4            4.12.3                   pypi_0    pypi
botocore                  1.31.17                  pypi_0    pypi
brotli-python             1.1.0            py39hfa9831e_2    conda-forge
bzip2                     1.0.8                h99b78c6_7    conda-forge
ca-certificates           2024.8.30            hf0a4a13_0    conda-forge
cellcharter               0.3.1                    pypi_0    pypi
certifi                   2024.8.30          pyhd8ed1ab_0    conda-forge
cffi                      1.17.1           py39h7f933ea_0    conda-forge
charset-normalizer        3.4.0              pyhd8ed1ab_0    conda-forge
chex                      0.1.87                   pypi_0    pypi
click                     8.1.7                    pypi_0    pypi
click-plugins             1.1.1                    pypi_0    pypi
cligj                     0.7.2                    pypi_0    pypi
cloudpickle               3.1.0                    pypi_0    pypi
colorcet                  3.1.0                    pypi_0    pypi
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
contextlib2               21.6.0                   pypi_0    pypi
contourpy                 1.3.0                    pypi_0    pypi
cpython                   3.9.20           py39hd8ed1ab_1    conda-forge
cycler                    0.12.1                   pypi_0    pypi
dask                      2024.8.0                 pypi_0    pypi
dask-expr                 1.1.10                   pypi_0    pypi
dask-image                2024.5.3                 pypi_0    pypi
datashader                0.16.3                   pypi_0    pypi
debugpy                   1.8.7            py39hfa9831e_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
distributed               2024.8.0                 pypi_0    pypi
docrep                    0.3.2                    pypi_0    pypi
etils                     1.5.2                    pypi_0    pypi
exceptiongroup            1.2.2              pyhd8ed1ab_0    conda-forge
executing                 2.1.0              pyhd8ed1ab_0    conda-forge
fasteners                 0.19                     pypi_0    pypi
filelock                  3.16.1             pyhd8ed1ab_0    conda-forge
flax                      0.8.5                    pypi_0    pypi
fonttools                 4.54.1                   pypi_0    pypi
freetype                  2.12.1               hadb7bae_2    conda-forge
frozenlist                1.5.0                    pypi_0    pypi
fsspec                    2023.6.0                 pypi_0    pypi
gdown                     5.2.0                    pypi_0    pypi
geopandas                 1.0.1                    pypi_0    pypi
get-annotations           0.1.2                    pypi_0    pypi
giflib                    5.2.2                h93a5062_0    conda-forge
gmp                       6.3.0                h7bae524_2    conda-forge
gmpy2                     2.1.5            py39h0bbb021_2    conda-forge
h2                        4.1.0              pyhd8ed1ab_0    conda-forge
h5py                      3.12.1                   pypi_0    pypi
hpack                     4.0.0              pyh9f0ad1d_0    conda-forge
humanize                  4.11.0                   pypi_0    pypi
hyperframe                6.0.1              pyhd8ed1ab_0    conda-forge
idna                      3.10               pyhd8ed1ab_0    conda-forge
igraph                    0.11.8                   pypi_0    pypi
imageio                   2.36.0                   pypi_0    pypi
importlib-metadata        8.5.0              pyha770c72_0    conda-forge
importlib-resources       6.4.5                    pypi_0    pypi
inflect                   7.4.0                    pypi_0    pypi
ipykernel                 6.29.5             pyh57ce528_0    conda-forge
ipython                   8.18.1             pyh707e725_3    conda-forge
jax                       0.4.30                   pypi_0    pypi
jaxlib                    0.4.30                   pypi_0    pypi
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
jmespath                  1.0.1                    pypi_0    pypi
joblib                    1.4.2                    pypi_0    pypi
jupyter_client            8.6.3              pyhd8ed1ab_0    conda-forge
jupyter_core              5.7.2              pyh31011fe_1    conda-forge
kiwisolver                1.4.7                    pypi_0    pypi
krb5                      1.21.3               h237132a_0    conda-forge
lazy-loader               0.4                      pypi_0    pypi
lcms2                     2.16                 ha0e7c42_0    conda-forge
legacy-api-wrap           1.4                      pypi_0    pypi
leidenalg                 0.10.2                   pypi_0    pypi
lerc                      4.0.0                h9a09cb3_0    conda-forge
libblas                   3.9.0           19_osxarm64_openblas    conda-forge
libcblas                  3.9.0           19_osxarm64_openblas    conda-forge
libcxx                    19.1.3               ha82da77_0    conda-forge
libdeflate                1.22                 hd74edd7_0    conda-forge
libedit                   3.1.20191231         hc8eb9b7_2    conda-forge
libffi                    3.4.2                h3422bc3_5    conda-forge
libgfortran               5.0.0           13_2_0_hd922786_3    conda-forge
libgfortran5              13.2.0               hf226fd6_3    conda-forge
libjpeg-turbo             3.0.0                hb547adb_1    conda-forge
liblapack                 3.9.0           19_osxarm64_openblas    conda-forge
libopenblas               0.3.24          openmp_hd76b1f2_0    conda-forge
libpng                    1.6.44               hc14010f_0    conda-forge
libsodium                 1.0.20               h99b78c6_0    conda-forge
libsqlite                 3.47.0               hbaaea75_1    conda-forge
libtiff                   4.7.0                hfce79cd_1    conda-forge
libwebp                   1.4.0                h54798ee_0    conda-forge
libwebp-base              1.4.0                h93a5062_0    conda-forge
libxcb                    1.17.0               hdb1d25a_0    conda-forge
libzlib                   1.3.1                h8359307_2    conda-forge
lightning                 2.4.0                    pypi_0    pypi
lightning-utilities       0.11.8                   pypi_0    pypi
llvm-openmp               15.0.7               h7cfbb63_0    conda-forge
llvmlite                  0.43.0                   pypi_0    pypi
locket                    1.0.0                    pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                3.0.2            py39h66d85bf_0    conda-forge
matplotlib                3.8.4                    pypi_0    pypi
matplotlib-inline         0.1.7              pyhd8ed1ab_0    conda-forge
matplotlib-scalebar       0.8.1                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
ml-collections            0.1.1                    pypi_0    pypi
ml-dtypes                 0.5.0                    pypi_0    pypi
more-itertools            10.5.0                   pypi_0    pypi
mpc                       1.3.1                h8f1351a_1    conda-forge
mpfr                      4.2.1                hb693164_3    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
msgpack                   1.1.0                    pypi_0    pypi
mudata                    0.2.4                    pypi_0    pypi
multidict                 6.1.0                    pypi_0    pypi
multipledispatch          1.0.0                    pypi_0    pypi
multiscale-spatial-image  1.0.1                    pypi_0    pypi
muon                      0.1.6                    pypi_0    pypi
natsort                   8.4.0                    pypi_0    pypi
ncurses                   6.5                  h7bae524_1    conda-forge
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
networkx                  3.2.1              pyhd8ed1ab_0    conda-forge
newick                    1.0.0                    pypi_0    pypi
numba                     0.60.0                   pypi_0    pypi
numcodecs                 0.12.1                   pypi_0    pypi
numpy                     1.26.4                   pypi_0    pypi
numpyro                   0.15.3                   pypi_0    pypi
ome-zarr                  0.9.0                    pypi_0    pypi
omnipath                  1.0.8                    pypi_0    pypi
openjpeg                  2.5.2                h9f1df11_0    conda-forge
openssl                   3.3.2                h8359307_0    conda-forge
opt-einsum                3.4.0                    pypi_0    pypi
optax                     0.2.3                    pypi_0    pypi
orbax-checkpoint          0.6.4                    pypi_0    pypi
packaging                 24.1               pyhd8ed1ab_0    conda-forge
pandas                    2.2.3                    pypi_0    pypi
param                     2.1.1                    pypi_0    pypi
parso                     0.8.4              pyhd8ed1ab_0    conda-forge
partd                     1.4.2                    pypi_0    pypi
patsy                     0.5.6                    pypi_0    pypi
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    11.0.0           py39h4ac03e3_0    conda-forge
pims                      0.7                      pypi_0    pypi
pip                       24.3.1             pyh8b19718_0    conda-forge
platformdirs              4.3.6              pyhd8ed1ab_0    conda-forge
pooch                     1.8.2                    pypi_0    pypi
prompt-toolkit            3.0.48             pyha770c72_0    conda-forge
propcache                 0.2.0                    pypi_0    pypi
protobuf                  5.28.3                   pypi_0    pypi
psutil                    6.1.0            py39h57695bc_0    conda-forge
pthread-stubs             0.4               hd74edd7_1002    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_0    conda-forge
pyarrow                   18.0.0                   pypi_0    pypi
pycparser                 2.22               pyhd8ed1ab_0    conda-forge
pyct                      0.5.0                    pypi_0    pypi
pygments                  2.18.0             pyhd8ed1ab_0    conda-forge
pynndescent               0.5.13                   pypi_0    pypi
pyogrio                   0.10.0                   pypi_0    pypi
pyparsing                 3.2.0                    pypi_0    pypi
pyproj                    3.6.1                    pypi_0    pypi
pyro-api                  0.1.2                    pypi_0    pypi
pyro-ppl                  1.9.1                    pypi_0    pypi
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.9.20          h9e33284_1_cpython    conda-forge
python-dateutil           2.9.0.post0              pypi_0    pypi
python_abi                3.9                      5_cp39    conda-forge
pytorch                   2.5.1                   py3.9_0    pytorch
pytorch-lightning         2.4.0                    pypi_0    pypi
pytz                      2024.2                   pypi_0    pypi
pyyaml                    6.0.2            py39h06df861_1    conda-forge
pyzmq                     26.2.0           py39h6e893d0_3    conda-forge
rasterio                  1.4.2                    pypi_0    pypi
readline                  8.2                  h92ec313_1    conda-forge
requests                  2.32.3             pyhd8ed1ab_0    conda-forge
rich                      13.9.4                   pypi_0    pypi
s3fs                      2023.6.0                 pypi_0    pypi
scanpy                    1.10.3                   pypi_0    pypi
scarches                  0.6.1                    pypi_0    pypi
schpl                     1.0.5                    pypi_0    pypi
scikit-image              0.24.0                   pypi_0    pypi
scikit-learn              1.5.2                    pypi_0    pypi
scipy                     1.13.1                   pypi_0    pypi
scvi-tools                1.1.6.post2              pypi_0    pypi
seaborn                   0.13.2                   pypi_0    pypi
session-info              1.0.0                    pypi_0    pypi
setuptools                75.3.0             pyhd8ed1ab_0    conda-forge
shapely                   2.0.6                    pypi_0    pypi
six                       1.16.0             pyh6c4a22f_0    conda-forge
sknw                      0.15                     pypi_0    pypi
slicerator                1.1.0                    pypi_0    pypi
sortedcontainers          2.4.0                    pypi_0    pypi
soupsieve                 2.6                      pypi_0    pypi
spatial-image             1.1.0                    pypi_0    pypi
spatialdata               0.2.3                    pypi_0    pypi
spatialdata-plot          0.2.7                    pypi_0    pypi
squidpy                   1.6.1                    pypi_0    pypi
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
statsmodels               0.14.4                   pypi_0    pypi
stdlib-list               0.11.0                   pypi_0    pypi
sympy                     1.13.1                   pypi_0    pypi
tblib                     3.0.0                    pypi_0    pypi
tensorstore               0.1.67                   pypi_0    pypi
texttable                 1.7.0                    pypi_0    pypi
threadpoolctl             3.5.0                    pypi_0    pypi
tifffile                  2024.8.30                pypi_0    pypi
tk                        8.6.13               h5083fa2_1    conda-forge
toolz                     1.0.0                    pypi_0    pypi
torchaudio                2.5.1                  py39_cpu    pytorch
torchgmm                  0.1.2                    pypi_0    pypi
torchmetrics              1.5.1                    pypi_0    pypi
torchvision               0.20.1                 py39_cpu    pytorch
tornado                   6.4.1            py39h06df861_1    conda-forge
tqdm                      4.66.6                   pypi_0    pypi
traitlets                 5.14.3             pyhd8ed1ab_0    conda-forge
typeguard                 4.4.0                    pypi_0    pypi
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024.2                   pypi_0    pypi
umap-learn                0.5.7                    pypi_0    pypi
urllib3                   1.26.20                  pypi_0    pypi
validators                0.34.0                   pypi_0    pypi
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
wheel                     0.44.0             pyhd8ed1ab_0    conda-forge
wrapt                     1.16.0                   pypi_0    pypi
xarray                    2024.7.0                 pypi_0    pypi
xarray-dataclasses        1.8.0                    pypi_0    pypi
xarray-datatree           0.0.15                   pypi_0    pypi
xarray-schema             0.0.3                    pypi_0    pypi
xarray-spatial            0.4.0                    pypi_0    pypi
xorg-libxau               1.0.11               hd74edd7_1    conda-forge
xorg-libxdmcp             1.1.5                hd74edd7_0    conda-forge
xz                        5.2.6                h57fd34a_0    conda-forge
yaml                      0.2.5                h3422bc3_2    conda-forge
yarl                      1.17.1                   pypi_0    pypi
zarr                      2.18.2                   pypi_0    pypi
zeromq                    4.3.5                h9f5b81c_6    conda-forge
zict                      3.0.0                    pypi_0    pypi
zipp                      3.20.2             pyhd8ed1ab_0    conda-forge
zstandard                 0.23.0           py39hcf1bb16_1    conda-forge
zstd                      1.5.6                hb46c0d2_0    conda-forge
LukasHats added the bug label on Nov 3, 2024
marcovarrone (Collaborator)

Hi @LukasHats, this is because CellCharter uses a slightly modified version of trVAE (I just removed the last ReLU layer).
So you always have to use CellCharter's implementation for training.
This means running:

```python
trvae = cc.tl.TRVAE(
    adata=adata,
    condition_key=condition_key,
    conditions=conditions,
    recon_loss='mse',
    use_mmd=False,
)

trvae.train(
    n_epochs=trvae_epochs,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
    enable_progress_bar=True,
)
```

I will add trVAE's implementation to CellCharter's documentation and clarify this aspect!
Thank you for pointing it out :)
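One property of a final ReLU worth keeping in mind (a fact about ReLU itself, not necessarily the stated reason for removing it): a decoder that ends in a ReLU can never output negative values. Arcsinh-transformed intensities of non-negative signals are themselves non-negative, so the clipping is harmless there, but z-scored markers are negative for every cell below the marker mean. The plain-Python sketch below, on made-up values, computes the best-case reconstruction error when the last linear layer reproduces the targets exactly and a ReLU is applied afterwards.

```python
def relu(x):
    """Rectified linear unit: clips negative inputs to 0."""
    return max(0.0, x)


def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


# Made-up z-scored values for one marker: some cells are below the mean.
target = [-1.2, -0.4, 0.1, 0.3, 1.2]

# Best case: the last linear layer outputs the target exactly.
with_relu = [relu(t) for t in target]  # negative targets are clipped to 0
without_relu = list(target)

print(mse(with_relu, target))     # > 0: error floor caused by the clipping
print(mse(without_relu, target))  # 0.0
```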

LukasHats (Author)

Ah, that makes sense! Thanks for the quick solution.
Two more questions:

  1. The trVAE documentation states that MSE should be used on (log1p-)normalized data, so I used that, or in my case arcsinh-transformed data. Would you train the cc.tl.TRVAE model on that data or on the scaled data, as you suggested?

  2. If I want to discover neighborhoods based on functional markers, e.g. metabolic markers, without cell-type markers interfering, would it work in theory to drop all markers except the metabolic ones from the adata and then run CellCharter? I suspect cluster stability might be problematic without strong cell-type-defining markers.
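Keeping only the functional markers amounts to subsetting the marker (variable) axis before CellCharter sees the data; with AnnData this is `adata[:, markers].copy()`, where `markers` is a hypothetical list of marker names. The standard-library sketch below shows the same operation on a toy matrix and fails loudly on misspelled marker names; the panel, marker names, and values are made up for illustration.

```python
def subset_markers(rows, var_names, keep):
    """Keep only the columns whose marker name is in `keep`, in var_names order."""
    keep_set = set(keep)
    unknown = keep_set - set(var_names)
    if unknown:
        # A typo in a marker name would otherwise silently drop that marker.
        raise KeyError(f"markers not found: {sorted(unknown)}")
    idx = [i for i, name in enumerate(var_names) if name in keep_set]
    return [[row[i] for i in idx] for row in rows], [var_names[i] for i in idx]


# Hypothetical panel: two cell-type markers and two metabolic markers.
var_names = ["CD3", "CD68", "GLUT1", "CPT1A"]
X = [[1.2, 0.1, 0.8, 0.3], [0.0, 2.1, 0.2, 1.5]]

X_metab, metab_names = subset_markers(X, var_names, ["GLUT1", "CPT1A"])
print(metab_names)  # ['GLUT1', 'CPT1A']
print(X_metab)      # [[0.8, 0.3], [0.2, 1.5]]
```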

marcovarrone (Collaborator) commented Nov 4, 2024

  1. Log or arcsinh normalization brings the data closer to a Gaussian distribution, so it can be combined with scaling. The two are not mutually exclusive!
  2. Very interesting question. Yes, keeping only the metabolic markers is the first approach that came to mind. The other option would be to use the cell type as a covariate in trVAE, but I think that risks distorting the embeddings in unpredictable ways.
    Regarding stability, I honestly cannot say without running it. There could still be "pockets" of stable solutions based on metabolic markers alone, but I cannot guarantee anything :)

EDIT: if you do both normalization and scaling, I would normalize first and scale afterwards, not the other way around!
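The recommended order, transforming first and scaling afterwards, can be sketched in plain Python on a toy cells × markers matrix. The cofactor of 5 is a common choice for IMC intensities but is an assumption here, and the per-marker z-score mirrors what `sc.pp.scale` does up to details such as the variance estimator and optional clipping. Scaling first would feed negative, mean-centred values into the arcsinh, so its compression would no longer act on the original intensity scale, and the result would generally not be zero-mean with unit variance.

```python
import math
import statistics


def arcsinh_transform(rows, cofactor=5.0):
    """Element-wise arcsinh(x / cofactor); the cofactor is an assumed value."""
    return [[math.asinh(v / cofactor) for v in row] for row in rows]


def scale_columns(rows):
    """Z-score each marker (column) to mean 0 and standard deviation 1."""
    columns = list(zip(*rows))
    stats = [(statistics.fmean(c), statistics.pstdev(c)) for c in columns]
    return [
        [(v - mu) / sd if sd > 0 else 0.0 for v, (mu, sd) in zip(row, stats)]
        for row in rows
    ]


# Toy raw intensities: 4 cells x 2 markers (made-up values).
raw = [[0.0, 3.0], [2.0, 40.0], [10.0, 5.0], [50.0, 1.0]]

transform_then_scale = scale_columns(arcsinh_transform(raw))  # recommended order
scale_then_transform = arcsinh_transform(scale_columns(raw))  # reverse order

# Each marker in the recommended result has mean ~0 and standard deviation ~1.
for marker in zip(*transform_then_scale):
    print(round(statistics.fmean(marker), 6), round(statistics.pstdev(marker), 6))
```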

LukasHats (Author)

Then I will try both with and without scaling. Thanks for your input! I will leave this issue open so you can close it after adding the documentation :)
