Skip to content

Commit

Permalink
Merge branch 'yuya/neva_llama3' into siglip_merge_llama3
Browse files Browse the repository at this point in the history
Signed-off-by: HuiyingLi <[email protected]>
  • Loading branch information
HuiyingLi committed May 13, 2024
2 parents b605b6e + 24785dd commit 39134ce
Show file tree
Hide file tree
Showing 14 changed files with 745 additions and 221 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ inference:
compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False
end_strings: ["<extra_id_1>","<extra_id_7>",] # generation will stop when one of these tokens is generated
media_base_path: /pwd/images # /path/to/images or /path/to/videos
insert_media_token: left # `left` or `right` or `null`
insert_media_token: null # `left` or `right` or `null`
media_type: image # `image` or `video`

trainer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from omegaconf import OmegaConf

from nemo.collections.multimodal.parts.utils import create_neva_model_and_processor
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam

CFG_STRING = """
trainer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def eval_model(args):
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument("--question-file", type=str, default="tables/question.json")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
parser.add_argument("--conv-mode", type=str, default="llava_v0")
parser.add_argument("--conv-mode", type=str, default="llava_v0") # this flag has no use!
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--num-chunks", type=int, default=1)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from argparse import ArgumentParser
from omegaconf.omegaconf import OmegaConf

from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.utils import logging


def get_args(argv=None):
    """Parse the command-line options of the checkpoint conversion script.

    Args:
        argv: Optional sequence of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            behaviour of reading the process command line, so existing
            ``get_args()`` callers are unaffected.

    Returns:
        argparse.Namespace with the attributes ``input_path``,
        ``output_path``, ``gpus_per_node``, ``num_nodes`` and ``precision``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            ``--precision`` is not one of the accepted choices.
    """
    parser = ArgumentParser()
    # Both paths are mandatory, so no default is declared for them.
    parser.add_argument(
        "--input_path",
        type=str,
        required=True,
        help="Path to NeMo legacy checkpoints",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to output .nemo file.",
    )
    parser.add_argument("--gpus_per_node", type=int, required=False, default=8)
    parser.add_argument("--num_nodes", type=int, required=False, default=1)
    parser.add_argument(
        "--precision",
        type=str,
        required=False,
        default='bf16-mixed',
        choices=['32-true', '16-mixed', 'bf16-mixed'],
        help="Precision value for the trainer that matches with precision of the ckpt",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Restore a NeVA model from a legacy checkpoint and save it as a .nemo file.

    Reads the command-line options via ``get_args()``, builds a trainer from a
    minimal config, restores ``MegatronNevaModel`` from ``--input_path`` and
    writes the result to ``--output_path``.
    """
    args = get_args()
    cfg = OmegaConf.create(
        {
            'trainer': {
                'devices': args.gpus_per_node,
                'num_nodes': args.num_nodes,
                'accelerator': 'gpu',
                'precision': args.precision,
            },
            # These model keys are only present for the trainer builder to read.
            # NOTE(review): presumably consumed when it sets up a grad scaler for
            # 16-bit precision - confirm against MegatronTrainerBuilder.
            'model': {
                'native_amp_init_scale': 2**32,
                'native_amp_growth_interval': 1000,
                'hysteresis': 2,
                'gradient_as_bucket_view': True,
            },
            'cluster_type': 'BCP',
        }
    )

    # The requested precision is left in the config on purpose. It used to be
    # overwritten with None right here, *before* the builder ran, which made
    # ``--precision`` (and the AMP keys above) dead settings.
    # NOTE(review): PTL >= 2.1 rejects a Trainer that is given both a precision
    # plugin and ``precision``; this relies on MegatronTrainerBuilder clearing
    # ``trainer.precision`` itself once it has created its plugin - confirm.
    trainer = MegatronTrainerBuilder(cfg).create_trainer()

    save_restore_connector = NLPSaveRestoreConnector()
    if os.path.isdir(args.input_path):
        # A directory input is handed to the connector as ``model_extracted_dir``
        # (presumably an already-unpacked checkpoint - verify against the connector).
        save_restore_connector.model_extracted_dir = args.input_path

    model = MegatronNevaModel.restore_from(
        restore_path=args.input_path,
        trainer=trainer,
        save_restore_connector=save_restore_connector,
        strict=False,
    )

    model.save_to(args.output_path)
    logging.info(f'NeMo model saved to: {args.output_path}')


# Run the conversion only when the file is executed as a script.
if __name__ == '__main__':
    main()
61 changes: 56 additions & 5 deletions nemo/collections/multimodal/data/neva/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from collections import defaultdict
from enum import Enum, auto
from typing import List

Expand All @@ -24,9 +25,14 @@
DEFAULT_SYSTEM_TOKEN = "<extra_id_0>"
DEFAULT_SEPARATOR_TOKEN = "<extra_id_1>"
DEFAULT_LABELS_TOKEN = "<extra_id_2>"
DEFAULT_IMAGE_PATCH_TOKEN = "<extra_id_3>"
DEFAULT_IM_START_TOKEN = "<extra_id_4>"
DEFAULT_IM_END_TOKEN = "<extra_id_5>"
# Media placeholder tokens, stored as defaultdicts keyed by a string identifier.
# Any key without an explicit entry resolves to the "<extra_id_*>" token given
# in the factory (note that a defaultdict lookup also inserts the missing key).
# NOTE(review): these are mappings, not plain strings - users of these names
# must index them, e.g. DEFAULT_IMAGE_PATCH_TOKEN["llama_3"].
DEFAULT_IMAGE_PATCH_TOKEN = defaultdict(lambda: "<extra_id_3>")
DEFAULT_IM_START_TOKEN = defaultdict(lambda: "<extra_id_4>")
DEFAULT_IM_END_TOKEN = defaultdict(lambda: "<extra_id_5>")

# Llama 3 override: the "llama_3" key maps to reserved special tokens instead
# of the "<extra_id_*>" fallbacks above.
DEFAULT_IMAGE_PATCH_TOKEN["llama_3"] = "<|reserved_special_token_3|>"
DEFAULT_IM_START_TOKEN["llama_3"] = "<|reserved_special_token_4|>"
DEFAULT_IM_END_TOKEN["llama_3"] = "<|reserved_special_token_5|>"


class SeparatorStyle(Enum):
Expand All @@ -36,6 +42,7 @@ class SeparatorStyle(Enum):
TWO = auto()
PLAIN = auto()
LLAMA_2 = auto()
LLAMA_3 = auto()
NVGPT = auto()


Expand Down Expand Up @@ -109,6 +116,34 @@ def get_prompt(self):
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.LLAMA_3:
"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
{{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|>
{{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}"
wrap_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>\n\n{msg}"
wrap_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>\n\n{msg}"

ret = "<|begin_of_text|>" + wrap_sys(self.system) + self.sep
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if type(message) is tuple:
message, _, _ = message
elif i % 2 == 0:
ret += wrap_user(message) + self.sep
else:
ret += wrap_assistant(message) + (self.sep if message else "")

elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
Expand Down Expand Up @@ -346,8 +381,25 @@ def dict(self):
sep2=DEFAULT_EOS_TOKEN,
)

# Conversation template using the LLAMA_3 separator style: roles are
# "user" / "assistant", turns are separated by "<|eot_id|>", and the template
# starts with no messages and a fixed system prompt.
# NOTE(review): version is "llama_v3" while the media-token tables in this file
# use the key "llama_3" - confirm which string callers match on.
conv_llava_llama_3 = Conversation(
system="You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language.",
roles=("user", "assistant"),
version="llama_v3",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_3,
sep="<|eot_id|>",
)

conv_llava_plain = Conversation(
system="", roles=("", ""), messages=(), offset=0, sep_style=SeparatorStyle.PLAIN, sep="\n",
system="",
roles=("", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="\n",
)

conv_llava_v0 = Conversation(
Expand Down Expand Up @@ -416,6 +468,5 @@ def dict(self):
"nv_dpo": conv_nv_dpo,
}


if __name__ == "__main__":
print(default_conversation.get_prompt())
Loading

0 comments on commit 39134ce

Please sign in to comment.