
Enable multi-device support for DPT #31066

Closed
wants to merge 6 commits

Conversation

OmarManzoor
Contributor

What does this PR do?

Adds multi-device support for DPT. I tested this on a Kaggle notebook with two T4 GPUs.
Screenshot 2024-05-27 at 8 11 03 PM

Towards #29786

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts @ArthurZucker
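
For context, a minimal sketch of what multi-device loading looks like from the user side once this support exists; the checkpoint name and dtype below are illustrative choices, not taken from this PR, and it assumes a machine with two visible GPUs plus the accelerate package installed.

# Sketch only: load a DPT checkpoint sharded across the visible GPUs.
import torch
from transformers import DPTForDepthEstimation

model = DPTForDepthEstimation.from_pretrained(
    "Intel/dpt-large",          # example checkpoint, not tied to this PR
    device_map="auto",          # let Accelerate place modules across cuda:0 / cuda:1
    torch_dtype=torch.float16,  # optional; halves the per-GPU memory footprint
)

# hf_device_map records which device each (non-split) module ended up on.
print(model.hf_device_map)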

@OmarManzoor
Contributor Author

I am not sure what happened, but the same test that passed previously is now failing.

Screenshot 2024-05-28 at 12 46 24 PM

@LysandreJik
Member

cc @SunMarc

@SunMarc
Member

SunMarc commented May 28, 2024

Hi @OmarManzoor, not sure what is happening. However, to make a model work on multiple devices you usually need to put some modules in _no_split_modules. In this case, the test is not passing because the embedding size is too big for the first GPU (more than half the size of the model). In test_model_parallelism, could you try making it work with
new_model = model_class.from_pretrained(tmp_dir, device_map="sequential", max_memory=max_memory), i.e. by changing device_map="auto" to "sequential"? This should fix the first assertion error. I will soon create a PR to do that, since the test is not well designed for unbalanced small models.
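
For reference, _no_split_modules is just a class attribute on the model's PreTrainedModel subclass; Accelerate's device-map inference treats each listed module as an atomic block that must stay on a single device. A rough sketch of the pattern (the module names are the ones tried later in this thread, and the real class also defines config_class, base_model_prefix, etc., which are omitted here):

# Sketch of the general pattern, not the PR's actual diff.
from transformers.modeling_utils import PreTrainedModel

class DPTPreTrainedModel(PreTrainedModel):
    # Modules listed here are never split across devices when a device_map is
    # inferred, so tensors that interact inside them stay on one GPU.
    _no_split_modules = ["DPTViTEmbeddings", "DPTViTSelfAttention"]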

@OmarManzoor
Contributor Author

In test_model_parallelism, could you try making it work with new_model = model_class.from_pretrained(tmp_dir, device_map="sequential", max_memory=max_memory), i.e. by changing device_map="auto" to "sequential"?

Hi @SunMarc,
If I do that, won't it actually change this test generically for all models, when it is supposed to test device_map="auto"? Or should I override this test only for this particular model?

@SunMarc
Member

SunMarc commented May 29, 2024

Hi @SunMarc,
If I do that, won't it actually change this test generically for all models, when it is supposed to test device_map="auto"? Or should I override this test only for this particular model?

The plan is to switch to sequential in the future. I need to check which tests will fail on our CI and fix them.
Right now, for this test, there is no way to pass it with device_map="auto" because of the unbalanced model. We could potentially skip the test, but it would be great if we can still enable multi-GPU support nevertheless.

@OmarManzoor
Contributor Author

Hi @SunMarc,
If I do that, won't it actually change this test generically for all models, when it is supposed to test device_map="auto"? Or should I override this test only for this particular model?

The plan is to switch to sequential in the future. I need to check which tests will fail on our CI and fix them. Right now, for this test, there is no way to pass it with device_map="auto" because of the unbalanced model. We could potentially skip the test, but it would be great if we can still enable multi-GPU support nevertheless.

Then maybe I should mark this particular test to be skipped for DPT?

@SunMarc
Member

SunMarc commented May 29, 2024

Then maybe I should mark this particular test to be skipped for DPT?

Sure, but right now it doesn't work with multiple GPUs. When I changed to sequential, I was hitting a device mismatch issue, so it would be great to fix this before merging!
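
For context, such device mismatches usually come from a tensor that stayed on an earlier stage's device being combined with a tensor Accelerate placed elsewhere; the usual fix is an explicit move at the point where they meet. A sketch of that pattern with placeholder names, not actual DPT code:

# Hypothetical illustration of the common device-mismatch fix.
import torch

def fuse(residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
    # When the model is split across GPUs, `residual` may still live on the
    # device of an earlier stage; move it next to `hidden_states` first.
    residual = residual.to(hidden_states.device)
    return residual + hidden_states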

@OmarManzoor
Contributor Author

So I tried adding
_no_split_modules = ["DPTViTEmbeddings", "DPTViTSelfAttention"]
to DPT, but now the other tests are failing as well, and test_model_parallelism returns the same error:

def test_model_parallelism(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        if model_class._no_split_modules is None:
            continue

        inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
        model = model_class(config).eval()
        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict_class)

        model_size = compute_module_sizes(model)[""]
        # We test several splits of sizes to make sure it works.
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            for max_size in max_gpu_sizes:
                max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
                new_model = model_class.from_pretrained(tmp_dir, device_map="sequential", max_memory=max_memory)
                # Making sure part of the model will actually end up offloaded
>               self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
E               AssertionError: Items in the second set but not the first:
E               0

tests/test_modeling_common.py:3161: AssertionError
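
One way to inspect the placement decision behind this assertion outside the test harness is the sketch below; it is a local debugging aid, not part of the PR, and assumes accelerate is installed and two GPUs are visible (infer_auto_device_map is roughly what from_pretrained uses under the hood when given a string device_map).

# Debugging sketch: reproduce the placement decision the failing assertion checks.
from accelerate import infer_auto_device_map
from accelerate.utils import compute_module_sizes
from transformers import DPTConfig, DPTForDepthEstimation

model = DPTForDepthEstimation(DPTConfig())
model_size = compute_module_sizes(model)[""]

device_map = infer_auto_device_map(
    model,
    # Give "GPU 0" only part of the model's total size, mirroring the test's budget.
    max_memory={0: model_size // 2, 1: model_size * 2, "cpu": model_size * 2},
    no_split_module_classes=["DPTViTEmbeddings", "DPTViTSelfAttention"],
)

# If device 0 never appears here (matching the AssertionError above), the first
# no-split block was already too large for GPU 0's budget and everything spilled
# onto device 1 / CPU.
print(set(device_map.values()))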

@github-actions
bot

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jul 5, 2024
@OmarManzoor OmarManzoor deleted the multi_device_dpt branch July 12, 2024 15:02