Skip to content

Commit

Permalink
* fix undefined device_map: use 'balanced' by default, or use the `to` met…
Browse files Browse the repository at this point in the history
…hod to move models to specified devices

* fix unrecognized dtype: `torch_dtype` must be a torch.dtype object, not a string like 'fp16'
* re-enable the unit test for image_diffusion_mapper
  • Loading branch information
HYLcool committed Dec 27, 2024
1 parent 13460e5 commit 4110a1a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
16 changes: 14 additions & 2 deletions data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@
'ram_plus_swin_large_14m.pth',
}

# Short precision aliases accepted in model parameters, mapped to the
# corresponding torch.dtype objects.
TORCH_DTYPE_MAPPING = {
    'fp32': torch.float32,
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
}


def get_backup_model_link(model_name):
for pattern, url in BACKUP_MODEL_LINKS.items():
Expand Down Expand Up @@ -282,8 +288,12 @@ def prepare_diffusion_model(pretrained_model_name_or_path, diffusion_type,
"""
AUTOINSTALL.check(['torch', 'transformers'])

if 'device' in model_params:
model_params['device_map'] = model_params.pop('device')
device = model_params.pop('device', None)
if not device:
model_params['device_map'] = 'balanced'
if 'torch_dtype' in model_params:
model_params['torch_dtype'] = TORCH_DTYPE_MAPPING[
model_params['torch_dtype']]

diffusion_type_to_pipeline = {
'image2image': diffusers.AutoPipelineForImage2Image,
Expand All @@ -300,6 +310,8 @@ def prepare_diffusion_model(pretrained_model_name_or_path, diffusion_type,
pipeline = diffusion_type_to_pipeline[diffusion_type]
model = pipeline.from_pretrained(pretrained_model_name_or_path,
**model_params)
if device:
model = model.to(device)

return model

Expand Down
7 changes: 1 addition & 6 deletions tests/ops/mapper/test_image_diffusion_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,8 @@
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.image_diffusion_mapper import ImageDiffusionMapper
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


# Skip tests for this OP in the GitHub actions due to OOM on the current runner
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ImageDiffusionMapperTest(DataJuicerTestCaseBase):

data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
Expand Down

0 comments on commit 4110a1a

Please sign in to comment.