-
Notifications
You must be signed in to change notification settings - Fork 0
/
census_scvi.py
187 lines (156 loc) · 5.82 KB
/
census_scvi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import logging
from typing import Optional, Union

import cellxgene_census
import click
import lightning.pytorch as pl
import scvi
import somacore
import torch
import torch.distributed as dist
import torchdata
from cellxgene_census.experimental.ml import ExperimentDataPipe
from cellxgene_census.experimental.ml.pytorch import pytorch_logger
from lightning.pytorch.callbacks import DeviceStatsMonitor
from scvi import REGISTRY_KEYS
from scvi.model import SCVI
from torch.utils.data import DataLoader
# Fix scvi-tools' global random seed so that training runs are reproducible.
scvi.settings.seed = 0

# Input width of the scvi module built in CensusSCVI: every X batch yielded by
# the datapipe must have exactly this many columns.
# NOTE(review): presumably the gene count of the Census measurement being
# trained on — TODO confirm it equals dp.shape[1] for the chosen release.
N_GENES = 60_664

logger = logging.getLogger("census_scvi")
logger.setLevel(logging.INFO)
class CensusDataLoader(DataLoader):
    """DataLoader that adapts ``ExperimentDataPipe`` batches to scvi-tools tensor dicts.

    Each item produced by the underlying loader is unpacked as ``(x, obs)``; the
    second element is discarded and ``x`` is emitted under the scvi registry keys.
    """

    def __init__(self, datapipe: ExperimentDataPipe, *args, **kwargs):
        """Wrap ``datapipe``; extra args/kwargs are forwarded to ``DataLoader``."""
        super().__init__(datapipe, *args, **kwargs)
        pytorch_logger.setLevel(logging.DEBUG)
        # dist.get_rank() raises when no default process group has been
        # initialized (e.g. the loader is built outside a DDP run), so only
        # query it when one exists and report rank 0 otherwise.
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        pytorch_logger.debug(f"pytorch dist rank={rank}, data shape={datapipe.shape}")

    def __iter__(self):
        """Yield one scvi tensor dict per batch from the wrapped datapipe."""
        for x, _obs in super().__iter__():
            # Cast to float32 to avoid "RuntimeError: mat1 and mat2 must have the
            # same dtype" (presumably the datapipe yields 64-bit floats while the
            # module weights are 32-bit).
            x = x.float()
            yield {
                REGISTRY_KEYS.X_KEY: x,
                # Every cell is assigned batch index 0 (no batch covariate).
                REGISTRY_KEYS.BATCH_KEY: torch.zeros((x.shape[0], 1)),
                REGISTRY_KEYS.LABELS_KEY: None,
                REGISTRY_KEYS.CONT_COVS_KEY: None,
                REGISTRY_KEYS.CAT_COVS_KEY: None,
            }
class CensusSCVI(SCVI):
    """SCVI model trained on cells streamed from a Census datapipe.

    ``SCVI.__init__`` is intentionally not called (presumably because it is built
    around a registered AnnData object, whereas here the data come from
    ``datapipe``). The scvi module and the bookkeeping attributes are set directly.
    """

    def __init__(
        self,
        datapipe: torchdata.datapipes.iter.IterDataPipe,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: str = "gene",
        gene_likelihood: str = "zinb",
        latent_distribution: str = "normal",
        **model_kwargs,
    ):
        """Build the scvi module and remember the datapipe to train on.

        Args:
            datapipe: Source of training batches; wrapped in ``CensusDataModule``
                when :meth:`train` runs.
            n_hidden, n_latent, n_layers, dropout_rate, dispersion,
            gene_likelihood, latent_distribution: Forwarded unchanged to the
                scvi module class (``self._module_cls``).
            **model_kwargs: Extra keyword arguments for the module class.
        """
        # The module's input width is fixed to N_GENES; it must match the number
        # of columns in the X batches yielded by the datapipe.
        self.module = self._module_cls(
            N_GENES,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            gene_likelihood=gene_likelihood,
            latent_distribution=latent_distribution,
            **model_kwargs,
        )
        self.datapipe = datapipe
        # Attributes the base class would normally initialise; presumably read by
        # scvi's train runner / model utilities — TODO confirm the full set.
        self.is_trained_ = False
        self._model_summary_string = ""
        self.train_indices_ = None
        self.test_indices_ = None
        self.validation_indices_ = None
        self.history_ = None

    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[bool] = None,
        accelerator: str = "auto",
        devices: Union[int, str] = "auto",
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """Train the model on the whole datapipe.

        Args:
            max_epochs: Maximum number of training epochs; forwarded to the runner.
            use_gpu: Forwarded unchanged to the train runner.
            accelerator: Lightning accelerator name (``main`` passes ``"gpu"`` or ``"cpu"``).
            devices: Lightning device spec — a count, a device string, or ``"auto"``.
            plan_kwargs: Keyword arguments for the training plan; any non-dict
                value (including ``None``) is treated as empty.
            **trainer_kwargs: Extra keyword arguments forwarded to the train runner
                (``main`` uses this for ``strategy``, ``profiler``, ``callbacks``).

        Returns:
            Whatever invoking the train runner returns.
        """
        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else {}
        training_plan = self._training_plan_cls(self.module, **plan_kwargs)
        datamodule = CensusDataModule(self.datapipe)
        # Report the trainer configuration through the module logger rather than
        # a bare print to stdout.
        logger.info("trainer kwargs: %s", trainer_kwargs)
        runner = self._train_runner_cls(
            self,
            training_plan=training_plan,
            data_splitter=datamodule,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            accelerator=accelerator,
            devices=devices,
            **trainer_kwargs,
        )
        return runner()
class CensusDataModule(pl.LightningDataModule):
    """Lightning data module that serves a single training loader over a datapipe.

    No train/validation/test split is materialised: the whole datapipe is used
    for training, and the split-index attributes stay ``None``.
    """

    def __init__(self, datapipe):
        super().__init__()
        self.datapipe = datapipe
        # Split indices are exposed as attributes (all None), presumably because
        # scvi's train runner reads them from its data splitter — TODO confirm.
        self.train_idx = self.val_idx = self.test_idx = None

    def setup(self, stage=None):
        """No-op: the datapipe is used as-is for every stage."""

    def train_dataloader(self):
        """Return a fresh ``CensusDataLoader`` over the whole datapipe."""
        loader = CensusDataLoader(self.datapipe)
        return loader

    def val_dataloader(self):
        """No validation data is provided; returns ``None``."""
        return None

    def test_dataloader(self):
        """No test data is provided; returns ``None``."""
        return None
@click.command()
@click.option("--census-uri", default=None, help="URI to census tiledb")
@click.option("--organism", default="homo_sapiens", help="Organism to use")
@click.option("--measurement-name", default="RNA")
@click.option("--layer-name", default="raw", help="Layer name to use")
@click.option("--obs-value-filter", default=None, type=str, help="Obs value filter to use")
@click.option("--torch-batch-size", default=128)
@click.option("--soma-buffer-bytes", type=int)
@click.option("--use-eager-fetch/--no-use-eager-fetch", default=True)
@click.option("--torch-devices", type=str, default=None)
@click.option("--max-epochs", default=1)
def main(
    census_uri,
    organism,
    measurement_name,
    layer_name,
    obs_value_filter,
    torch_batch_size,
    soma_buffer_bytes,
    use_eager_fetch,
    torch_devices,
    max_epochs,
) -> None:
    """Train an scVI model on cells streamed from the CELLxGENE Census.

    Opens the Census at --census-uri if given (otherwise the default location of
    cellxgene_census.open_soma()), streams the selected organism / measurement /
    X layer through an ExperimentDataPipe, and trains CensusSCVI on it. Uses the
    GPU accelerator when --torch-devices is set, otherwise a single CPU device.
    """
    pytorch_logger.setLevel(logging.DEBUG)
    census = cellxgene_census.open_soma(uri=census_uri) if census_uri else cellxgene_census.open_soma()
    # Use the Census handle as a context manager so it is closed once training
    # finishes or raises.
    with census:
        dp = ExperimentDataPipe(
            census["census_data"][organism],
            measurement_name=measurement_name,
            X_name=layer_name,
            obs_query=somacore.AxisQuery(value_filter=obs_value_filter),
            batch_size=int(torch_batch_size),
            soma_buffer_bytes=soma_buffer_bytes,
            use_eager_fetch=use_eager_fetch,
        )
        print(f"training data shape={dp.shape}")

        # NOTE: shuffling is currently disabled (previously sketched as dp.shuffle()).
        model = CensusSCVI(dp)
        model.train(
            max_epochs=int(max_epochs),
            accelerator="gpu" if torch_devices else "cpu",
            devices=torch_devices if torch_devices else 1,
            strategy="ddp_find_unused_parameters_true",
            profiler="simple",
            callbacks=[DeviceStatsMonitor()],
            # For iterable datasets, see
            # https://pytorch-lightning.readthedocs.io/en/1.7.7/guides/data.html#iterable-datasets and
            # https://lightning.ai/docs/pytorch/stable/common/trainer.html#val-check-interval
            # val_check_interval=100, check_val_every_n_epoch=None,
        )


if __name__ == "__main__":
    main()