Support quantization with adapter v1 and v2 finetuning #694

Merged: 7 commits into Lightning-AI:main on Jan 19, 2024

Conversation

@safurrier (Contributor) commented on Nov 2, 2023

Closes #392

Implements quantization for adapter and adapter v2 with the same code as used for LoRA.

Ran it on StableLM 3B and Llama 2 7B for adapter and adapter_v2:

| Model | Adapter Version | Settings | Training Memory | Training Time | Inference Memory |
| --- | --- | --- | --- | --- | --- |
| StableLM 3B | v1 | Default (bf16-mixed) | 25.52 GB | 0.83 min | 7.34 GB |
| StableLM 3B | v1 | --precision bf16-true | 9.12 GB | 0.68 min | 7.34 GB |
| StableLM 3B | v1 | --precision bf16-true --quantize bnb.nf4 | 8.23 GB | 1.61 min | 8.23 GB |
| StableLM 3B | v1 | --precision bf16-true --quantize bnb.nf4-dq | 8.23 GB | 1.63 min | 8.23 GB |
| Llama 2 7B | v1 | Default (bf16-mixed) | OutOfMemoryError | N/A | 13.58 GB |
| Llama 2 7B | v1 | --precision bf16-true | 21.30 GB | 1.61 min | 13.58 GB |
| Llama 2 7B | v1 | --precision bf16-true --quantize bnb.nf4 | 14.08 GB | 3.08 min | 14.08 GB |
| Llama 2 7B | v1 | --precision bf16-true --quantize bnb.nf4-dq | 14.08 GB | 3.17 min | 14.08 GB |
| StableLM 3B | v2 | Default (bf16-mixed) | 30.05 GB | 1.00 min | 7.34 GB |
| StableLM 3B | v2 | --precision bf16-true | 10.73 GB | 0.81 min | 7.34 GB |
| StableLM 3B | v2 | --precision bf16-true --quantize bnb.nf4 | 8.23 GB | 1.7 min | 8.23 GB |
| StableLM 3B | v2 | --precision bf16-true --quantize bnb.nf4-dq | 8.23 GB | 1.74 min | 8.23 GB |
| Llama 2 7B | v2 | Default (bf16-mixed) | OutOfMemoryError | N/A | 13.59 GB |
| Llama 2 7B | v2 | --precision bf16-true | 26.91 GB | 2.12 min | 13.59 GB |
| Llama 2 7B | v2 | --precision bf16-true --quantize bnb.nf4 | 19.72 GB | 3.38 min | 14.09 GB |
| Llama 2 7B | v2 | --precision bf16-true --quantize bnb.nf4-dq | 19.39 GB | 3.47 min | 14.09 GB |

There is now some duplicated code across the lora, adapter, and adapter_v2 scripts for setting up quantization (mainly building the precision plugin from the quantization flag and selecting a BnB-compatible optimizer). That could be cleaned up with a common utility, but I didn't want to refactor to that level unless it's desired.
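If that refactor is desired, something like the following could centralize the setup (a rough sketch only; quantization_setup and make_optimizer are hypothetical helper names, not functions in this PR, and the exact BitsandbytesPrecision/bitsandbytes calls should be double-checked against the scripts):

# Hypothetical shared helper -- illustration of the duplicated setup logic only.
from typing import Optional, Tuple

import torch
from lightning.fabric.plugins import BitsandbytesPrecision


def quantization_setup(
    quantize: Optional[str], precision: str
) -> Tuple[Optional[str], Optional[BitsandbytesPrecision]]:
    """Turn the --quantize flag into a Fabric precision plugin (quantization needs a *-true precision)."""
    if quantize is None or not quantize.startswith("bnb."):
        return precision, None
    dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
    plugin = BitsandbytesPrecision(quantize[4:], dtype)  # e.g. "bnb.nf4" -> mode "nf4"
    return None, plugin  # Fabric takes either a precision string or a plugin, not both


def make_optimizer(params, lr: float, weight_decay: float, quantized: bool):
    """Pick a bitsandbytes paged optimizer when weights are bnb-quantized, plain AdamW otherwise."""
    if quantized:
        import bitsandbytes as bnb
        return bnb.optim.PagedAdamW(params, lr=lr, weight_decay=weight_decay)
    return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)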

@safurrier (Contributor, Author) commented on Nov 2, 2023

The most time-consuming part was running the finetuning sweep across the 2 models, 2 adapter versions, and 4 precision/quantization settings.

I wrote a Makefile that runs it all. In case something like that is useful for others, I'll leave it here:

# Params
########

# Adapter V1 or V2
# For V1 leave blank, for V2 use _v2 (keep only one assignment active)
# ADAPTER_VERSION=
ADAPTER_VERSION=_v2
ADAPTER_FINETUNE_CMD=python finetune/adapter$(ADAPTER_VERSION).py
ADAPTER_GENERATE_CMD=python generate/adapter$(ADAPTER_VERSION).py --prompt "Recommend a movie to watch on the weekend."

# Change model here
MODEL_SOURCE=stabilityai
MODEL_NAME=$(MODEL_SOURCE)/stablelm-base-alpha-3b

# MODEL_SOURCE=meta-llama
# MODEL_NAME=$(MODEL_SOURCE)/Llama-2-7b-chat-hf

run-all-adapters: setup
	@echo "Running adapter with stabilityai (v1)..."
	@mkdir -p logs/stabilityai
	@make adapter ADAPTER_VERSION= MODEL_SOURCE=stabilityai MODEL_NAME=stabilityai/stablelm-base-alpha-3b

	@echo "Running adapter with stabilityai (_v2)..."
	@mkdir -p logs/stabilityai
	@make adapter ADAPTER_VERSION=_v2 MODEL_SOURCE=stabilityai MODEL_NAME=stabilityai/stablelm-base-alpha-3b


	@echo "Running adapter with meta-llama (v1)..."
	@mkdir -p logs/meta-llama
	@make adapter ADAPTER_VERSION= MODEL_SOURCE=meta-llama MODEL_NAME=meta-llama/Llama-2-7b-chat-hf

	@echo "Running adapter with meta-llama (_v2)..."
	@mkdir -p logs/meta-llama
	@make adapter ADAPTER_VERSION=_v2 MODEL_SOURCE=meta-llama MODEL_NAME=meta-llama/Llama-2-7b-chat-hf


CHECKPOINT_DIR=checkpoints/$(MODEL_NAME)
LOG_SUFFIX=.txt

# Setup
#######

# Install Python dependencies
requirements:
	pip install huggingface_hub sentencepiece
	pip install -r requirements-all.txt


# Download Models
download-stable-3b: requirements
	python scripts/download.py --repo_id stabilityai/stablelm-base-alpha-3b
	python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b
	python scripts/prepare_alpaca.py

download-llama-7b: requirements
	python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf --access_token XXX
	python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
	python scripts/prepare_alpaca.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf

# Prerequisites check
######################

# Check CUDA availability
cuda-check:
	@echo "Checking for CUDA..."
	@/usr/bin/env nvidia-smi > /dev/null 2>&1 || (echo "CUDA is not available or nvidia-smi is not in your PATH"; exit 1)

logs:
	mkdir -p logs/$(MODEL_SOURCE)

prereqs: logs cuda-check

setup: download-stable-3b download-llama-7b

# Train Adapter
finetune-adapter-default: prereqs
	$(ADAPTER_FINETUNE_CMD) \
	--checkpoint_dir $(CHECKPOINT_DIR) \
	--out_dir out/adapter$(ADAPTER_VERSION)/$(MODEL_NAME)/default \
	2>&1 | tee logs/$(MODEL_NAME)-finetune-adapter$(ADAPTER_VERSION)-default$(LOG_SUFFIX)

finetune-adapter-bf16: prereqs
	$(ADAPTER_FINETUNE_CMD) \
	--precision bf16-true \
	--checkpoint_dir $(CHECKPOINT_DIR) \
	--out_dir out/adapter$(ADAPTER_VERSION)/$(MODEL_NAME)/bf16 \
	2>&1 | tee logs/$(MODEL_NAME)-finetune-adapter$(ADAPTER_VERSION)-bf16$(LOG_SUFFIX)

finetune-adapter-bf16-bnb.nf4: prereqs
	$(ADAPTER_FINETUNE_CMD) \
	--precision bf16-true \
	--quantize "bnb.nf4" \
	--checkpoint_dir $(CHECKPOINT_DIR) \
	--out_dir out/adapter$(ADAPTER_VERSION)/$(MODEL_NAME)/bf16-bnb-nf4 \
	2>&1 | tee logs/$(MODEL_NAME)-finetune-adapter$(ADAPTER_VERSION)-bf16-bnb.nf4$(LOG_SUFFIX)

finetune-adapter-bf16-bnb.nf4-dq: prereqs
	$(ADAPTER_FINETUNE_CMD) \
	--precision bf16-true \
	--quantize "bnb.nf4-dq" \
	--checkpoint_dir $(CHECKPOINT_DIR) \
	--out_dir out/adapter$(ADAPTER_VERSION)/$(MODEL_NAME)/bf16-bnb-nf4-dq \
	 2>&1 | tee logs/$(MODEL_NAME)-finetune-adapter$(ADAPTER_VERSION)-bf16-bnb.nf4-dq$(LOG_SUFFIX)

finetune-all: finetune-adapter-default finetune-adapter-bf16 finetune-adapter-bf16-bnb.nf4 finetune-adapter-bf16-bnb.nf4-dq

# Generate (check inference memory)
generate-adapter-default: prereqs
	$(ADAPTER_GENERATE_CMD) \
	--checkpoint_dir $(CHECKPOINT_DIR) \
	--adapter_path out/adapter$(ADAPTER_VERSION)/$(MODEL_NAME)/default/lit_model_adapter_finetuned.pth \
	2>&1 | tee logs/$(MODEL_NAME)-generate-adapter$(ADAPTER_VERSION)-default$(LOG_SUFFIX)

generate-adapter-bf16: prereqs
	$(ADAPTER_GENERATE_CMD) \
	--precision bf16-true \
	--checkpoint_dir $(CHECKPOINT_DIR) \
	--adapter_path out/adapter$(ADAPTER_VERSION)/$(MODEL_NAME)/bf16/lit_model_adapter_finetuned.pth \
	2>&1 | tee logs/$(MODEL_NAME)-generate-adapter$(ADAPTER_VERSION)-bf16$(LOG_SUFFIX)

generate-adapter-bf16-bnb.nf4: prereqs
	$(ADAPTER_GENERATE_CMD) \
	--precision bf16-true \
	--quantize "bnb.nf4" \
	--checkpoint_dir $(CHECKPOINT_DIR) \
	--adapter_path out/adapter$(ADAPTER_VERSION)/$(MODEL_NAME)/bf16-bnb-nf4/lit_model_adapter_finetuned.pth \
	2>&1 | tee logs/$(MODEL_NAME)-generate-adapter$(ADAPTER_VERSION)-bf16-bnb.nf4$(LOG_SUFFIX)

generate-adapter-bf16-bnb.nf4-dq: prereqs
	$(ADAPTER_GENERATE_CMD) \
	--precision bf16-true \
	--quantize "bnb.nf4-dq" \
	--checkpoint_dir $(CHECKPOINT_DIR) \
	--adapter_path out/adapter$(ADAPTER_VERSION)/$(MODEL_NAME)/bf16-bnb-nf4-dq/lit_model_adapter_finetuned.pth \
	2>&1 | tee logs/$(MODEL_NAME)-generate-adapter$(ADAPTER_VERSION)-bf16-bnb.nf4-dq$(LOG_SUFFIX)

# Finetune and generate combined
adapter-default: finetune-adapter-default generate-adapter-default
adapter-bf16: finetune-adapter-bf16 generate-adapter-bf16
adapter-bf16-bnb.nf4: finetune-adapter-bf16-bnb.nf4 generate-adapter-bf16-bnb.nf4
adapter-bf16-bnb.nf4-dq: finetune-adapter-bf16-bnb.nf4-dq generate-adapter-bf16-bnb.nf4-dq
adapter: adapter-default adapter-bf16 adapter-bf16-bnb.nf4 adapter-bf16-bnb.nf4-dq

(A review comment on finetune/adapter.py was marked outdated and resolved.)
@Andrei-Aksionov (Collaborator) commented on Nov 3, 2023

Hey @safurrier
Thanks for the PR!
And thanks as well for the Makefile example 👍.


The BitsandbytesPrecision plugin works here without any hiccups, since the trainable adapter weights are stored as nn.Embedding and the plugin only replaces nn.Linear. FYI: the plugin swaps every linear layer for a custom linear implementation that quantizes its weights when they are moved to the CUDA device.
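For context, here is a minimal sketch of how such a plugin is handed to Fabric (illustration only, not the exact wiring in the finetuning scripts; bitsandbytes and a CUDA device are required for the quantization to actually happen):

import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision

# "bnb.nf4" corresponds to mode "nf4"; the plugin swaps every nn.Linear for a bnb linear
# whose weight is quantized to 4-bit once the module is moved to the CUDA device.
plugin = BitsandbytesPrecision(mode="nf4", dtype=torch.float16)
fabric = Fabric(devices=1, plugins=plugin)
fabric.launch()

# model = GPT(config)                 # in the scripts the model is built under fabric.init_module()
# model = fabric.setup_module(model)  # linear weights then show up as packed torch.uint8, as in the dump below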

As a sanity check, I took a look at the dtypes of each layer when quantization is applied, using this command for adapter_v1:

python finetune/adapter.py --checkpoint_dir checkpoints/EleutherAI/pythia-70m --quantize bnb.nf4 --precision 16-true
_forward_module.lm_head.weight ----------------------------- torch.uint8
_forward_module.transformer.wte.weight --------------------- torch.float16
...
_forward_module.transformer.h.2.norm_1.weight -------------- torch.float16
_forward_module.transformer.h.2.norm_1.bias ---------------- torch.float16
_forward_module.transformer.h.2.attn.gating_factor --------- torch.float16
_forward_module.transformer.h.2.attn.attn.weight ----------- torch.uint8
_forward_module.transformer.h.2.attn.attn.bias ------------- torch.float16
_forward_module.transformer.h.2.attn.proj.weight ----------- torch.uint8
_forward_module.transformer.h.2.attn.proj.bias ------------- torch.float16
_forward_module.transformer.h.2.attn.adapter_wte.weight ---- torch.float16
_forward_module.transformer.h.2.norm_2.weight -------------- torch.float16
_forward_module.transformer.h.2.norm_2.bias ---------------- torch.float16
_forward_module.transformer.h.2.mlp.fc.weight -------------- torch.uint8
_forward_module.transformer.h.2.mlp.fc.bias ---------------- torch.float16
_forward_module.transformer.h.2.mlp.proj.weight ------------ torch.uint8
_forward_module.transformer.h.2.mlp.proj.bias -------------- torch.float16
...
_forward_module.transformer.ln_f.weight -------------------- torch.float16
_forward_module.transformer.ln_f.bias ---------------------- torch.float16

and for adapter_v2:

python finetune/adapter_v2.py --checkpoint_dir checkpoints/EleutherAI/pythia-70m --quantize bnb.nf4 --precision 16-true
_forward_module.lm_head.adapter_bias ----------------------- torch.float16
_forward_module.lm_head.adapter_scale ---------------------- torch.float16
_forward_module.lm_head.linear.weight ---------------------- torch.uint8
_forward_module.transformer.wte.weight --------------------- torch.float16
...
_forward_module.transformer.h.2.norm_1.weight -------------- torch.float16
_forward_module.transformer.h.2.norm_1.bias ---------------- torch.float16
_forward_module.transformer.h.2.attn.gating_factor --------- torch.float16
_forward_module.transformer.h.2.attn.attn.adapter_bias ----- torch.float16
_forward_module.transformer.h.2.attn.attn.adapter_scale ---- torch.float16
_forward_module.transformer.h.2.attn.attn.linear.weight ---- torch.uint8
_forward_module.transformer.h.2.attn.attn.linear.bias ------ torch.float16
_forward_module.transformer.h.2.attn.proj.adapter_bias ----- torch.float16
_forward_module.transformer.h.2.attn.proj.adapter_scale ---- torch.float16
_forward_module.transformer.h.2.attn.proj.linear.weight ---- torch.uint8
_forward_module.transformer.h.2.attn.proj.linear.bias ------ torch.float16
_forward_module.transformer.h.2.attn.adapter_wte.weight ---- torch.float16
_forward_module.transformer.h.2.norm_2.weight -------------- torch.float16
_forward_module.transformer.h.2.norm_2.bias ---------------- torch.float16
_forward_module.transformer.h.2.mlp.fc.adapter_bias -------- torch.float16
_forward_module.transformer.h.2.mlp.fc.adapter_scale ------- torch.float16
_forward_module.transformer.h.2.mlp.fc.linear.weight ------- torch.uint8
_forward_module.transformer.h.2.mlp.fc.linear.bias --------- torch.float16
_forward_module.transformer.h.2.mlp.proj.adapter_bias ------ torch.float16
_forward_module.transformer.h.2.mlp.proj.adapter_scale ----- torch.float16
_forward_module.transformer.h.2.mlp.proj.linear.weight ----- torch.uint8
_forward_module.transformer.h.2.mlp.proj.linear.bias ------- torch.float16
...
_forward_module.transformer.ln_f.weight -------------------- torch.float16
_forward_module.transformer.ln_f.bias ---------------------- torch.float16

As you can see, in both cases only the weight matrices for attention (QKV and projection), the MLP, and lm_head are quantized.
Other parameters, including the trainable adapter's adapter_wte.weight, are kept in non-quantized form.


To display the dtypes, one can put this code snippet right after fabric.setup:

# print the dtype of every parameter, with the name padded for alignment
for name, layer in model.named_parameters():
    print(f"{(name + ' ').ljust(60, '-')} {layer.dtype}")

@safurrier (Contributor, Author) commented:
@Andrei-Aksionov good to know. I didn't dig too deeply into things and was hoping this would mainly work out of the box.

Is it the desired behavior that the quantization skips the new adapter params? Or for LoRA is the entire model quantized as well?

@Andrei-Aksionov (Collaborator) commented on Nov 9, 2023

... was hoping this would mainly work out of the box.

Basically, that's what you can expect when dealing with Lightning products 😄. It might sound a bit cheesy, but from my experience it's true.

Is it the desired behavior that the quantization skips the new adapter params? Or for LoRA is the entire model quantized as well?

Yes, this is exactly the desired behavior.
With QLoRA only the pretrained weights are quantized. Trainable LoRA weights are kept in non-quantized form.
During each forward pass the pretrained weights are dequantized, the matmuls for the pretrained and LoRA weights are computed, and the two results are summed.

(Screenshot attached, 2023-11-09)

With adapter it's a bit different, but the idea is the same.
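A rough, framework-agnostic sketch of that forward pass (illustration only; QLoRALinearSketch is made up for this comment, it is not the lit-gpt or bitsandbytes implementation, and the 4-bit packing/dequantization is faked with a frozen tensor):

import torch
import torch.nn as nn


class QLoRALinearSketch(nn.Module):
    """Illustration only: frozen 'quantized' base weight plus trainable low-rank LoRA factors."""

    def __init__(self, in_features: int, out_features: int, r: int = 8):
        super().__init__()
        # In a real QLoRA layer the base weight is stored packed in 4-bit NF4 and is frozen;
        # here a plain frozen tensor stands in for it.
        self.base_weight_q = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        # The trainable LoRA factors stay in the regular compute dtype (fp16/bf16 in practice).
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1) dequantize the frozen pretrained weight to the compute dtype (a no-op in this sketch)
        w = self.base_weight_q.to(x.dtype)
        # 2) matmul with the pretrained weight and with the low-rank LoRA update
        base_out = x @ w.T
        lora_out = (x @ self.lora_A.T) @ self.lora_B.T
        # 3) sum the two results
        return base_out + lora_out


# e.g. QLoRALinearSketch(512, 512)(torch.randn(4, 512)) has shape (4, 512)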

@carmocca changed the title from "add quantization to adapter and adapter v2 finetune" to "Support quantization with adapter v1 and v2 finetuning" on Jan 19, 2024
@carmocca (Contributor) left a comment:

Thank you! I added tests and included the results directly in the resource tables.

@carmocca merged commit 0f021f3 into Lightning-AI:main on Jan 19, 2024 (8 of 9 checks passed).
rasbt pushed a commit that referenced this pull request on Mar 18, 2024.

Successfully merging this pull request may close these issues:

4-Bit quantization for Adapter and Adapter v2 methods
3 participants