deepspeed inference #3241

Open
Reginald-L opened this issue Nov 17, 2024 · 0 comments

Reginald-L commented Nov 17, 2024

Hi, I am using DeepSpeed ZeRO-3 to fine-tune a FLUX model with the kohya scripts.
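For context, ZeRO-3 parameter offload is normally set in the DeepSpeed config used at initialization. The sketch below shows that section as a Python dict. The `zero_optimization`, `offload_param`, and `bf16` keys are standard DeepSpeed config keys, but the batch size and precision values are assumptions, and the config the kohya scripts actually generate may differ. DeepSpeed only accepts `offload_param` with ZeRO stage 3; the values here mirror the `device='cpu'` and `pin_memory=True` settings in the code below.

```python
import json

# Assumed values for illustration only; the kohya-generated config may differ.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        # Keep partitioned parameters in pinned CPU memory; they are fetched to the GPU when needed.
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
}

if __name__ == "__main__":
    # Print the config in the JSON form DeepSpeed config files use.
    print(json.dumps(ds_config, indent=2))
```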

I want to sample images every 50 steps during training, so I tried using DeepSpeed-Inference. Here is my code:
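Triggering sampling every 50 steps can be sketched as a modulo check on the global step. The names `should_sample` and `train` are hypothetical, and the optimizer step is left as a comment. One caveat under ZeRO-3: each rank holds only a shard of the parameters, and the forward pass gathers the rest with collective communication, so the sampling call generally has to run on every rank, not only on the main process.

```python
from typing import Callable, List


def should_sample(global_step: int, every_n_steps: int = 50) -> bool:
    """Return True on steps every_n_steps, 2 * every_n_steps, ... (never on step 0)."""
    return every_n_steps > 0 and global_step > 0 and global_step % every_n_steps == 0


def train(num_steps: int, every_n_steps: int, sample_fn: Callable[[int], None]) -> None:
    """Hypothetical training-loop skeleton; steps are counted from 1."""
    for step in range(1, num_steps + 1):
        # ... forward, backward and optimizer step would go here ...
        if should_sample(step, every_n_steps):
            # Under ZeRO-3, call this on every rank so the parameter gathers can complete.
            sample_fn(step)


if __name__ == "__main__":
    sampled_steps: List[int] = []
    train(num_steps=120, every_n_steps=50, sample_fn=sampled_steps.append)
    print(sampled_steps)  # [50, 100]
```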

import deepspeed
import torch
from deepspeed.runtime.zero.config import DeepSpeedZeroOffloadParamConfig
from tqdm import tqdm

# model, img, img_ids, txt, txt_ids, vec, timesteps, guidance_vec, t5_attn_mask
# and accelerator are defined earlier in the training script (not shown)

# offload parameters to pinned CPU memory
zero_offload_config = DeepSpeedZeroOffloadParamConfig()
zero_offload_config.device = 'cpu'
zero_offload_config.pin_memory = True

# deepspeed inference
infer_engine = deepspeed.init_inference(
    model,
    dtype=torch.bfloat16,
    zero={
        "stage": 3,
        "offload_param": zero_offload_config
    }
)

# make the inputs contiguous (note: .contiguous() does not move tensors to the GPU)
img = img.contiguous()
img_ids = img_ids.contiguous()
txt = txt.contiguous()
txt_ids = txt_ids.contiguous()
vec = vec.contiguous()

for t_curr, t_prev in zip(tqdm(timesteps[:-1]) if accelerator.is_main_process else timesteps[:-1], timesteps[1:]):
    # broadcast the current timestep to every sample in the batch
    t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
    pred = infer_engine(
        img=img,
        img_ids=img_ids,
        txt=txt,
        txt_ids=txt_ids,
        y=vec,
        timesteps=t_vec,
        guidance=guidance_vec,
        txt_attention_mask=t5_attn_mask,
    )
    torch.cuda.synchronize()
    if torch.distributed.is_initialized():
        # barrier(async_op=True) returns a work handle and does not block unless .wait() is called
        torch.distributed.barrier()

    # Euler step from t_curr to t_prev along the predicted velocity
    img = img + (t_prev - t_curr) * pred
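The last line of the loop, `img = img + (t_prev - t_curr) * pred`, is an explicit Euler step along the predicted velocity from `t_curr` to `t_prev`. The scalar sketch below isolates that update so it can be checked without a GPU; `euler_sample` and the `velocity` callable are hypothetical stand-ins for the latent tensor and the FLUX model's prediction.

```python
from typing import Callable, Sequence


def euler_sample(
    x: float,
    timesteps: Sequence[float],
    velocity: Callable[[float, float], float],
) -> float:
    """Integrate dx/dt = velocity(x, t) over the schedule with explicit Euler steps."""
    # Pair each timestep with the next one, like zip(timesteps[:-1], timesteps[1:]).
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        pred = velocity(x, t_curr)        # stand-in for the model's prediction
        x = x + (t_prev - t_curr) * pred  # negative step size for a descending schedule
    return x


if __name__ == "__main__":
    schedule = [1.0, 0.75, 0.5, 0.25, 0.0]
    print(euler_sample(1.0, schedule, lambda x, t: 2.0))  # -1.0
```

With a descending schedule from 1.0 to 0.0 and a constant velocity of 2.0, the steps add up to (0.0 - 1.0) * 2.0, so the sample moves from 1.0 to -1.0.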

But I still get an OOM (out-of-memory) error (the error output was attached as a screenshot).

There is no NVLink in my machines.
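A rough sense of the memory involved: Black Forest Labs describes FLUX.1 [dev] as a 12-billion-parameter transformer (assuming that is the model being fine-tuned here), so its weights alone take about 22.4 GiB in bf16 at 2 bytes per parameter. If the inference engine ends up holding a full, unpartitioned copy on each GPU (an assumption, not something this issue confirms), that copy sits alongside the ZeRO-3 training state, text encoders, and activations. The helper below counts weight memory only; `weights_gib` and `per_gpu_gib` are hypothetical names, the 4-rank figure is an arbitrary example, and the even split is idealized ZeRO-3 partitioning that ignores buffers and temporarily gathered parameters.

```python
GIB = 2**30  # bytes per GiB


def weights_gib(num_params: float, bytes_per_param: int) -> float:
    """Memory needed just to store the weights, in GiB."""
    return num_params * bytes_per_param / GIB


def per_gpu_gib(num_params: float, bytes_per_param: int, world_size: int) -> float:
    """Idealized ZeRO-3 share: weights split evenly across world_size ranks."""
    return weights_gib(num_params, bytes_per_param) / world_size


if __name__ == "__main__":
    flux_params = 12e9  # assumption: ~12B parameters (FLUX.1 [dev] transformer)
    print(f"bf16 full copy: {weights_gib(flux_params, 2):.1f} GiB")              # 22.4 GiB
    print(f"fp32 full copy: {weights_gib(flux_params, 4):.1f} GiB")              # 44.7 GiB
    print(f"bf16 per GPU, 4 ranks: {per_gpu_gib(flux_params, 2, 4):.1f} GiB")    # 5.6 GiB
```

The lack of NVLink mainly affects bandwidth: parameter gathers go over PCIe, which slows ZeRO-3 down but does not change how many bytes each GPU has to hold.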
