From ee0933ce993a8c5596bba1335c910f478ef5cf91 Mon Sep 17 00:00:00 2001
From: Bo Li
Date: Thu, 16 May 2024 05:50:41 +0000
Subject: [PATCH] Fix config overwrite bug in train.py

---
 llava/train/train.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llava/train/train.py b/llava/train/train.py
index 00235af47..4aa12c3d8 100755
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -1224,6 +1224,7 @@ def get_model(model_args, training_args, bnb_model_from_pretrained_args):
     customized_kwargs.update(bnb_model_from_pretrained_args)
 
     overwrite_config = {}
+    cfg_pretrained = None
     if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
         cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
         overwrite_config["rope_scaling"] = {
@@ -1246,6 +1247,9 @@ def get_model(model_args, training_args, bnb_model_from_pretrained_args):
         overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode
 
     if overwrite_config:
+        if cfg_pretrained is None:
+            cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
+
         rank0_print(f"Overwriting config with {overwrite_config}")
         for k, v in overwrite_config.items():
             setattr(cfg_pretrained, k, v)