Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: yaoyu-33 <[email protected]>
  • Loading branch information
yaoyu-33 committed Jul 8, 2024
1 parent 188686c commit d3e728e
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def main(cfg) -> None:
n, c, h = cfg.model.micro_batch_size, cfg.model.channels, cfg.model.image_size
x = torch.randn((n, c, h, h), dtype=torch.float32, device="cuda")
t = torch.randint(77, (n,), device="cuda")
cc = torch.randn((n, 77, cfg.model.unet_config.context_dim), dtype=torch.float32, device="cuda",)
cc = torch.randn(
(n, 77, cfg.model.unet_config.context_dim),
dtype=torch.float32,
device="cuda",
)
if cfg.model.precision in [16, '16']:
x = x.type(torch.float16)
cc = cc.type(torch.float16)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def _training_strategy(self) -> NLPDDPStrategy:
_IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
if _IS_INTERACTIVE and self.cfg.trainer.devices == 1:
logging.info("Detected interactive environment, using NLPDDPStrategyNotebook")
return NLPDDPStrategyNotebook(no_ddp_communication_hook=True, find_unused_parameters=False,)
return NLPDDPStrategyNotebook(
no_ddp_communication_hook=True,
find_unused_parameters=False,
)

if self.cfg.model.get('fsdp', False):
assert (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def __init__(self, cfg, model_parallel_config):
self._init_first_stage(first_stage_config)
self.model_type = None

self.rng = torch.Generator(device=torch.cuda.current_device(),)
self.rng = torch.Generator(
device=torch.cuda.current_device(),
)

self.use_ema = False # TODO use_ema need to switch to NeMo style
if self.use_ema:
Expand Down Expand Up @@ -192,7 +194,12 @@ def training_step(self, batch, batch_idx):
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False)

self.log(
"global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False,
"global_step",
self.global_step,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)

if self.scheduler_config is not None:
Expand Down Expand Up @@ -238,7 +245,11 @@ def configure_optimizers(self):
scheduler = DiffusionEngine.from_config_dict(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1,}
{
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
}
]
return [opt], scheduler
return opt
Expand Down Expand Up @@ -298,7 +309,14 @@ def set_input_tensor(self, input_tensor):
pass

@torch.no_grad()
def log_images(self, batch: Dict, N: int = 8, sample: bool = True, ucg_keys: List[str] = None, **kwargs,) -> Dict:
def log_images(
self,
batch: Dict,
N: int = 8,
sample: bool = True,
ucg_keys: List[str] = None,
**kwargs,
) -> Dict:
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
if ucg_keys:
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
Expand All @@ -312,7 +330,8 @@ def log_images(self, batch: Dict, N: int = 8, sample: bool = True, ucg_keys: Lis
x = self.get_input(batch)

c, uc = self.conditioner.get_unconditional_conditioning(
batch, force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
batch,
force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
)

sampling_kwargs = {}
Expand Down Expand Up @@ -407,7 +426,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):
# handle asynchronous grad reduction
no_sync_func = None
if not forward_only and self.with_distributed_adam:
no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
no_sync_func = partial(
self._optimizer.no_sync,
greedy_grad_copy=self.megatron_amp_O2,
)

# pipeline schedules will get these from self.model.config
for module in self.get_module_list():
Expand Down Expand Up @@ -445,12 +467,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):

def training_step(self, dataloader_iter):
"""
Our dataloaders produce a micro-batch and then we fetch
a number of microbatches depending on the global batch size and model parallel size
from the dataloader to produce a list of microbatches.
Batch should be a list of microbatches and those microbatches should on CPU.
Microbatches are then moved to GPU during the pipeline.
The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
Our dataloaders produce a micro-batch and then we fetch
a number of microbatches depending on the global batch size and model parallel size
from the dataloader to produce a list of microbatches.
Batch should be a list of microbatches and those microbatches should on CPU.
Microbatches are then moved to GPU during the pipeline.
The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
"""
self._optimizer.zero_grad()

Expand Down Expand Up @@ -498,20 +520,20 @@ def training_step(self, dataloader_iter):
return loss_mean

def backward(self, *args, **kwargs):
""" LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from apex.
No need to call it here.
"""LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from apex.
No need to call it here.
"""
pass

def optimizer_zero_grad(self, *args, **kwargs):
""" LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""
pass

def _append_sequence_parallel_module_grads(self, module, grads):
""" Helper method for allreduce_sequence_parallel_gradients"""
"""Helper method for allreduce_sequence_parallel_gradients"""

for param in module.parameters():
sequence_parallel_param = getattr(param, 'sequence_parallel', False)
Expand All @@ -524,12 +546,13 @@ def _append_sequence_parallel_module_grads(self, module, grads):

def get_forward_output_and_loss_func(self):
def process_batch(batch):
""" Prepares the global batch for apex fwd/bwd functions.
Global batch is a list of micro batches.
"""Prepares the global batch for apex fwd/bwd functions.
Global batch is a list of micro batches.
"""
# SD has more dedicated structure for encoding, so we enable autocasting here as well
with torch.cuda.amp.autocast(
self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype,
self.autocast_dtype in (torch.half, torch.bfloat16),
dtype=self.autocast_dtype,
):
if self.model.precache_mode == 'both':
x = batch[self.model.input_key].to(torch.cuda.current_device())
Expand Down Expand Up @@ -572,7 +595,7 @@ def validation_step(self, dataloader_iter, batch_idx):
return loss

def setup(self, stage=None):
""" PTL hook that is executed after DDP spawns.
"""PTL hook that is executed after DDP spawns.
We setup datasets here as megatron datasets require DDP to instantiate.
See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information.
Args:
Expand Down Expand Up @@ -685,20 +708,23 @@ def setup_test_data(self, cfg):
f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}'
)
self._test_dl = torch.utils.data.DataLoader(
self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True,
self._test_ds,
batch_size=self._micro_batch_size,
num_workers=cfg.num_workers,
pin_memory=True,
)

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
""" PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""
return batch

def _validate_trainer(self):
""" Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""
if self.trainer.accumulate_grad_batches > 1:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def __init__(

def _state_key_mapping(self, state_dict: dict):
import re

res_dict = {}
key_list = state_dict.keys()
key_str = " ".join(key_list)
Expand Down Expand Up @@ -397,7 +398,7 @@ def _state_key_mapping(self, state_dict: dict):
res_dict[key_] = val_
return res_dict

def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo = False):
def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False):
if not from_NeMo:
state_dict = self._state_key_mapping(state_dict)
model_state_dict = self.state_dict()
Expand All @@ -408,7 +409,10 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from
unexpected_keys = list(set(loaded_keys) - set(expected_keys))

def _find_mismatched_keys(
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes,
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
Expand Down Expand Up @@ -443,7 +447,10 @@ def _find_mismatched_keys(
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes,
state_dict,
model_state_dict,
original_loaded_keys,
ignore_mismatched_sizes,
)
error_msgs = self._load_state_dict_into_model(state_dict)
return missing_keys, unexpected_keys, mismatched_keys, error_msgs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@ def __init__(
logging.info(f"Attempting to load pretrained unet from {from_pretrained}")
if from_pretrained.endswith('safetensors'):
from safetensors.torch import load_file as load_safetensors

state_dict = load_safetensors(from_pretrained)
else:
state_dict = torch.load(from_pretrained, map_location='cpu')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def __init__(
):
self.num_steps = num_steps
self.discretization = instantiate_from_config(discretization_config)
self.guider = instantiate_from_config(default(guider_config, DEFAULT_GUIDER,))
self.guider = instantiate_from_config(
default(
guider_config,
DEFAULT_GUIDER,
)
)
self.verbose = verbose
self.device = device

Expand Down Expand Up @@ -103,22 +108,22 @@ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0,
sigma_hat = sigma * (gamma + 1.0)
if gamma > 0:
eps = torch.randn_like(x) * self.s_noise
x = x + eps * append_dims(sigma_hat ** 2 - sigma ** 2, x.ndim) ** 0.5
x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
# this is the noise (e_t)
d = to_d(x, sigma_hat, denoised)
dt = append_dims(next_sigma - sigma_hat, x.ndim)

euler_step = self.euler_step(x, d, dt) # this is x_{t-\delta{t}}
euler_step = self.euler_step(x, d, dt) # this is x_{t-\delta{t}}
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
if return_noise:
return x, d
return x

def get_gamma(self, sigmas, num_sigmas, index):
gamma = (
min(self.s_churn / (num_sigmas - 1), 2 ** 0.5 - 1) if self.s_tmin <= sigmas[index] <= self.s_tmax else 0.0
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[index] <= self.s_tmax else 0.0
)
return gamma

Expand All @@ -128,9 +133,17 @@ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):

for i in self.get_sigma_gen(num_sigmas):
gamma = self.get_gamma(sigmas, num_sigmas, i)
x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, gamma,)
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
gamma,
)
return x


class AncestralSampler(SingleStepDiffusionSampler):
def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
Expand Down Expand Up @@ -158,14 +171,24 @@ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc,)
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
)

return x


class LinearMultistepSampler(BaseDiffusionSampler):
def __init__(
self, order=4, *args, **kwargs,
self,
order=4,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -283,7 +306,15 @@ def get_mult(self, h, r, t, t_next, previous_sigma):
return mult1, mult2

def sampler_step(
self, old_denoised, previous_sigma, sigma, next_sigma, denoiser, x, cond, uc=None,
self,
old_denoised,
previous_sigma,
sigma,
next_sigma,
denoiser,
x,
cond,
uc=None,
):
denoised = self.denoise(x, denoiser, sigma, cond, uc)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,9 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.
x = torch.cat((x, c.get("concat")), dim=1)

return self.diffusion_model(
x, timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), **kwargs,
x,
timesteps=t,
context=c.get("crossattn", None),
y=c.get("vector", None),
**kwargs,
)
Loading

0 comments on commit d3e728e

Please sign in to comment.