From 53cb640636c4314fd388925e586018b17580e555 Mon Sep 17 00:00:00 2001 From: kui huang <82303451+pkhk-1@users.noreply.github.com> Date: Mon, 1 Apr 2024 10:32:25 +0800 Subject: [PATCH] fix sft gradient (#494) Co-authored-by: LokeZhou --- paddlemix/config/llava/pretrain.json | 2 ++ paddlemix/examples/llava/run_predict_multiround.py | 2 +- paddlemix/trainer/llava_trainer.py | 3 --- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/paddlemix/config/llava/pretrain.json b/paddlemix/config/llava/pretrain.json index 284e8626c..ced6bcdc5 100644 --- a/paddlemix/config/llava/pretrain.json +++ b/paddlemix/config/llava/pretrain.json @@ -1,5 +1,7 @@ { "model_name_or_path": "paddlemix/llava/vicuna-13b-v1.5", + "freeze_include": ["*llama*", "*lm_head*"], + "freeze_exclude": ["*llama.mm_projector*"], "dataset": { "train":[{"name": "chatml_dataset", "data_files": "train.json"}], "eval":[{"name": "chatml_dataset", "data_files": "val.json"}] diff --git a/paddlemix/examples/llava/run_predict_multiround.py b/paddlemix/examples/llava/run_predict_multiround.py index 2553e15b3..79bd09d94 100644 --- a/paddlemix/examples/llava/run_predict_multiround.py +++ b/paddlemix/examples/llava/run_predict_multiround.py @@ -109,7 +109,7 @@ def main(args): with paddle.no_grad(): output_ids = model.generate( input_ids=data_dict["input_ids"], - images=data_dict["images"], + images=paddle.cast(data_dict["images"],compute_dtype), image_sizes=[image_size], decode_strategy="sampling" if args.temperature > 0 else "greedy_search", temperature=args.temperature, diff --git a/paddlemix/trainer/llava_trainer.py b/paddlemix/trainer/llava_trainer.py index 448f880ae..2864e7130 100644 --- a/paddlemix/trainer/llava_trainer.py +++ b/paddlemix/trainer/llava_trainer.py @@ -141,9 +141,6 @@ def create_optimizer(self, lr_scheduler=None): opt_model = self.model - # case pretrain - for p in self.model.parameters(): - p.stop_gradient = True for p in self.model.llama.mm_projector.parameters(): p.stop_gradient = not True