-
Notifications
You must be signed in to change notification settings - Fork 507
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ca364b3
commit d5586d3
Showing
7 changed files
with
185 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from argparse import ArgumentParser | ||
|
||
from . import BaseAutoTrainCommand | ||
|
||
|
||
def run_tools_command_factory(args):
    """Build the ``tools`` command object from parsed CLI arguments."""
    command = RunAutoTrainToolsCommand(args)
    return command
|
||
|
||
class RunAutoTrainToolsCommand(BaseAutoTrainCommand):
    """CLI command exposing standalone AutoTrain tools (adapter merge, Kohya conversion)."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the ``tools`` sub-command and its tool-specific sub-parsers to *parser*."""
        tools_parser = parser.add_parser("tools", help="Run AutoTrain tools")
        tool_subparsers = tools_parser.add_subparsers(title="tools", dest="tool_name")

        # merge-llm-adapter: merge a trained LoRA adapter back into its base LLM.
        merge_parser = tool_subparsers.add_parser(
            "merge-llm-adapter",
            help="Merge LLM Adapter tool",
        )
        # (flag, argparse keyword arguments) — declared as data to keep the wiring compact.
        merge_flag_specs = [
            ("--base-model-path", {"type": str, "help": "Base model path"}),
            ("--adapter-path", {"type": str, "help": "Adapter path"}),
            ("--token", {"type": str, "help": "Token", "default": None, "required": False}),
            ("--pad-to-multiple-of", {"type": int, "help": "Pad to multiple of", "default": None, "required": False}),
            ("--output-folder", {"type": str, "help": "Output folder", "required": False, "default": None}),
            ("--push-to-hub", {"action": "store_true", "help": "Push to Hugging Face Hub", "required": False}),
        ]
        for flag, options in merge_flag_specs:
            merge_parser.add_argument(flag, **options)
        # merge_llm_adapter=True marks the parsed namespace so run() knows which tool to invoke.
        merge_parser.set_defaults(func=run_tools_command_factory, merge_llm_adapter=True)

        # convert_to_kohya: rewrite a LoRA state dict in Kohya format.
        kohya_parser = tool_subparsers.add_parser("convert_to_kohya", help="Convert to Kohya tool")
        for flag, flag_help in (("--input-path", "Input path"), ("--output-path", "Output path")):
            kohya_parser.add_argument(flag, type=str, help=flag_help)
        kohya_parser.set_defaults(func=run_tools_command_factory, convert_to_kohya=True)

    def __init__(self, args):
        # Parsed argparse namespace; flags injected via set_defaults() select the tool to run.
        self.args = args

    def run(self):
        """Dispatch to whichever tool was selected on the command line."""
        if getattr(self.args, "merge_llm_adapter", False):
            self.run_merge_llm_adapter()
        if getattr(self.args, "convert_to_kohya", False):
            self.run_convert_to_kohya()

    def run_merge_llm_adapter(self):
        """Run the adapter-merge tool with the parsed CLI arguments."""
        # Imported lazily so the heavy ML stack loads only when this tool is used.
        from autotrain.tools.merge_adapter import merge_llm_adapter

        merge_kwargs = {
            "base_model_path": self.args.base_model_path,
            "adapter_path": self.args.adapter_path,
            "token": self.args.token,
            "output_folder": self.args.output_folder,
            "pad_to_multiple_of": self.args.pad_to_multiple_of,
            "push_to_hub": self.args.push_to_hub,
        }
        merge_llm_adapter(**merge_kwargs)

    def run_convert_to_kohya(self):
        """Run the Kohya conversion tool with the parsed CLI arguments."""
        from autotrain.tools.convert_to_kohya import convert_to_kohya

        convert_to_kohya(input_path=self.args.input_path, output_path=self.args.output_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya | ||
from safetensors.torch import load_file, save_file | ||
|
||
from autotrain import logger | ||
|
||
|
||
def convert_to_kohya(input_path, output_path):
    """Convert a LoRA safetensors checkpoint into a Kohya-format state dict.

    Loads the state dict at *input_path*, converts its keys to the PEFT layout,
    then to Kohya naming, and writes the result to *output_path*.
    """
    logger.info(f"Converting Lora state dict from {input_path} to Kohya state dict at {output_path}")
    state_dict = load_file(input_path)
    # Normalize to the PEFT key layout before mapping to Kohya names.
    state_dict = convert_all_state_dict_to_peft(state_dict)
    state_dict = convert_state_dict_to_kohya(state_dict)
    save_file(state_dict, output_path)
    logger.info(f"Kohya state dict saved at {output_path}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import torch | ||
from peft import PeftModel | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from autotrain import logger | ||
from autotrain.trainers.common import ALLOW_REMOTE_CODE | ||
|
||
|
||
def merge_llm_adapter(
    base_model_path, adapter_path, token, output_folder=None, pad_to_multiple_of=None, push_to_hub=False
):
    """Merge a PEFT LoRA adapter into its base causal LM and save and/or push the result.

    Args:
        base_model_path: Hub id or local path of the base model.
        adapter_path: Hub id or local path of the trained adapter; also used as the
            repo id when pushing the merged model to the Hub.
        token: Hugging Face auth token (may be ``None`` for public models).
        output_folder: Local directory to save the merged model and tokenizer to.
        pad_to_multiple_of: If set, pad the resized token embedding matrix to a
            multiple of this value.
        push_to_hub: When ``True``, push the merged model and tokenizer to the Hub.

    Raises:
        ValueError: If neither ``output_folder`` nor ``push_to_hub`` is given, since
            the merged model would otherwise be discarded.
    """
    if output_folder is None and not push_to_hub:
        # Message fixed to match the actual CLI flag spelling (dashes, not underscores).
        raise ValueError("You must specify either --output-folder or --push-to-hub")

    logger.info("Loading adapter...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=ALLOW_REMOTE_CODE,
        token=token,
    )

    # Tokenizer comes from the adapter repo: training may have added tokens,
    # so the base model's embeddings are resized to match it below.
    tokenizer = AutoTokenizer.from_pretrained(
        adapter_path,
        trust_remote_code=ALLOW_REMOTE_CODE,
        token=token,
    )
    if pad_to_multiple_of:
        base_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=pad_to_multiple_of)
    else:
        base_model.resize_token_embeddings(len(tokenizer))

    model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        token=token,
    )
    # Fold the adapter weights into the base weights and drop the PEFT wrapper.
    model = model.merge_and_unload()

    if output_folder is not None:
        logger.info("Saving target model...")
        model.save_pretrained(output_folder)
        tokenizer.save_pretrained(output_folder)
        logger.info(f"Model saved to {output_folder}")

    if push_to_hub:
        logger.info("Pushing model to Hugging Face Hub...")
        model.push_to_hub(adapter_path)
        tokenizer.push_to_hub(adapter_path)
        logger.info(f"Model pushed to Hugging Face Hub as {adapter_path}")