Skip to content

Commit

Permalink
merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
victorchall committed Mar 2, 2024
1 parent be2cec7 commit d098223
Show file tree
Hide file tree
Showing 3 changed files with 463 additions and 38 deletions.
194 changes: 156 additions & 38 deletions caption_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,84 @@
import io
import argparse
import time
from typing import Generator
import json
import logging
import re
from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Literal

import torch
from torchvision import transforms

from PIL import Image
import PIL.ImageOps as ImageOps
from pynvml import *

from transformers import AutoModelForCausalLM, LlamaTokenizer
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer
from colorama import Fore, Style

from plugins.caption_plugins import load_prompt_alteration_plugin

SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
IMAGE_SIZE: int = 490
PATCH_SIZE: int = 14

def build_conversation_input_ids(
    tokenizer: PreTrainedTokenizer,
    *,
    query: str,
    history: Optional[List[Tuple[str, str]]] = None,
    images: Optional[List[Image.Image]] = None,
    starts_with: Optional[str] = None,
):
    """Tokenize a single-turn visual-question prompt for CogVLM.

    Based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py

    Args:
        tokenizer: tokenizer providing bos/pad token ids and encode().
        query: the question text inserted into the prompt template.
        history: unused beyond normalization; kept for upstream API parity.
        images: at most one PIL image; preprocessed to a normalized tensor.
        starts_with: optional text appended after "Answer: " to steer output.

    Returns:
        dict with 1-D long tensors 'input_ids', 'token_type_ids',
        'attention_mask', and 'images' (list with one preprocessed tensor,
        or the value passed in when no image is supplied).
    """
    assert images is None or len(images) <= 1, f"not support multi images by now."
    history = history or []

    # Prompt template; optionally seed the answer with starts_with text.
    text = f"Question: {query} Answer: "
    if starts_with is not None:
        text += starts_with

    input_ids = [tokenizer.bos_token_id]
    token_type_ids = [0]

    if images is not None and len(images) == 1:
        # Vision path: resize/normalize the image and reserve pad-token slots
        # for the vision tokens (token_type 1 marks vision positions).
        preprocess = transforms.Compose(
            [
                transforms.Resize(
                    (IMAGE_SIZE, IMAGE_SIZE), interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        images = [preprocess(images[0])]
        # one slot per patch, plus 2 (presumably boi/eoi markers -- per upstream code)
        vision_token_num = (IMAGE_SIZE // PATCH_SIZE) * (IMAGE_SIZE // PATCH_SIZE) + 2
        input_ids.extend([tokenizer.pad_token_id] * vision_token_num)
        token_type_ids.extend([1] * vision_token_num)

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    input_ids.extend(text_ids)
    token_type_ids.extend([0] * len(text_ids))

    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
        'attention_mask': torch.tensor([1] * len(input_ids), dtype=torch.long),
        'images': images,
    }

def image_generator(image_dir) -> Generator[str, None, None]:
    """Walk image_dir recursively, yielding the path of every file whose
    name ends with one of the SUPPORTED_EXT extensions."""
    for root, _dirs, filenames in os.walk(image_dir):
        for filename in filenames:
            if any(filename.endswith(ext) for ext in SUPPORTED_EXT):
                yield os.path.join(root, filename)
def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
    """Yield paths of supported image files found in image_dir.

    Args:
        image_dir: directory to scan.
        do_recurse: when True, walk the whole tree; otherwise scan only the
            top-level directory.
    """
    def _is_supported(filename: str) -> bool:
        # Match on filename suffix against the module's extension whitelist.
        return any(filename.endswith(ext) for ext in SUPPORTED_EXT)

    if do_recurse:
        for root, _dirs, filenames in os.walk(image_dir):
            yield from (os.path.join(root, f) for f in filenames if _is_supported(f))
    else:
        yield from (os.path.join(image_dir, f) for f in os.listdir(image_dir) if _is_supported(f))

def get_gpu_memory_map():
nvmlInit()
Expand All @@ -44,13 +104,27 @@ def get_gpu_memory_map():
nvmlShutdown()
return info.used/1024/1024

def save_params(args, gen_kwargs):
    """Record the run configuration for reproducibility.

    Writes the parsed CLI arguments and the generation kwargs as pretty-printed
    JSON to caption_cog_params.txt inside the target image directory.

    Args:
        args: argparse.Namespace; must provide image_dir.
        gen_kwargs: dict of generation keyword arguments (JSON-serializable).
    """
    save_path = os.path.join(args.image_dir, "caption_cog_params.txt")
    args_dict = {
        "args": vars(args),
        "gen_kwargs": gen_kwargs,
    }
    pretty_print = json.dumps(args_dict, indent=4)
    # Explicit utf-8: matches how caption files are written elsewhere in this
    # script, and avoids UnicodeEncodeError on platforms whose default
    # encoding (e.g. cp1252) can't represent non-ASCII prompt/append text.
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(pretty_print)


def main(args):
prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)

tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
model = AutoModelForCausalLM.from_pretrained(
'THUDM/cogvlm-chat-hf',
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True,
trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
#revision=... # no one is actually doing this
load_in_4bit=not args.disable_4bit,
)

Expand All @@ -61,8 +135,8 @@ def main(args):
args.temp = args.temp or 1.0

args.append = args.append or ""
if len(args.append) > 0 and not args.append.startswith(" "):
args.append = " " + args.append
if len(args.append) > 0:
args.append = " " + args.append.strip()

gen_kwargs = {
"max_length": args.max_length,
Expand All @@ -80,52 +154,61 @@ def main(args):
}

if args.max_new_tokens is not None:
print(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
del gen_kwargs["max_length"]

if not do_sample:
print(f"** Using greedy search instead sampling. Generated captions will be deterministic; meaning it will be the same even if you run this program multiple times.")
logging.info(f"** Using greedy sampling")
del gen_kwargs["top_k"]
del gen_kwargs["top_p"]
del gen_kwargs["temperature"]
else:
print(f"** Sampling enabled")
logging.info(f"** Sampling enabled")

force_words_ids = None
if args.force_words is not None:
force_words = args.force_words.split(",") if args.force_words is not None else []
print(f"** force_words: {Fore.LIGHTGREEN_EX}{force_words}{Style.RESET_ALL}")
logging.info(f"** force_words: {Fore.LIGHTGREEN_EX}{force_words}{Style.RESET_ALL}")
force_words_ids = tokenizer(force_words, add_special_tokens=False)["input_ids"] if force_words else []

bad_words_ids = None
if args.bad_words is not None:
bad_words = args.bad_words.split(",") if args.bad_words is not None else []
print(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}")
logging.info(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}")
bad_words_ids = tokenizer(bad_words, add_special_tokens=False)["input_ids"] if bad_words else []

print(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}")
logging.info(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}")

save_params(args, gen_kwargs)

total_start_time = time.time()
i_processed = 0

for image_path in image_generator(args.image_dir):
starts_with = args.starts_with.strip()

for i, image_path in enumerate(image_generator(args.image_dir, do_recurse=not args.no_recurse)):
candidate_caption_path = image_path.replace(os.path.splitext(image_path)[-1], ".txt")

if args.no_overwrite and os.path.exists(candidate_caption_path):
print(f"Skipping {image_path}, caption already exists.")
logging.warning(f"Skipping {image_path}, caption already exists.")
continue

start_time = time.time()
cap_start_time = time.time()
image = Image.open(image_path)

try:
image = image.convert('RGB')
image = ImageOps.exif_transpose(image)
except Exception as e:
print(f"Non-fatal error processing {image_path}: {e}")
logging.warning(f"Non-fatal error processing {image_path}: {e}")
continue

logging.debug(f" __ Prompt before plugin: {Fore.LIGHTGREEN_EX}{args.prompt}{Style.RESET_ALL}")
prompt = prompt_plugin_fn(image_path, args=args)
logging.debug(f" __ Modified prompt after plugin: {Fore.LIGHTGREEN_EX}{prompt}{Style.RESET_ALL}")

inputs = build_conversation_input_ids(tokenizer, query=prompt, history=[], images=[image], starts_with=args.starts_with) # chat mode

inputs = model.build_conversation_input_ids(tokenizer, query=args.prompt, history=[], images=[image]) # chat mode
inputs = {
'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
Expand All @@ -134,27 +217,53 @@ def main(args):
}

with torch.no_grad():
#input_decoded = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
#logging.debug(f"inputs decoded: {input_decoded}")
outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
outputs_without_prompt = outputs[:, inputs['input_ids'].shape[1]:]

len_inputs = inputs['input_ids'].shape[1]
outputs_without_prompt = outputs[:, len_inputs:]

caption = tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
if not args.remove_starts_with:
# deal with caption starting with comma, etc
if not re.match(r"^\W", caption):
caption = starts_with + " " + caption
else:
caption = starts_with + caption

caption += args.append

with open(candidate_caption_path, "w", encoding="utf-8") as f:
with open(candidate_caption_path, "w") as f:
f.write(caption)
vram_gb = get_gpu_memory_map()
elapsed_time = time.time() - start_time
print(f"VRAM: {Fore.LIGHTYELLOW_EX}{vram_gb:0.1f} GB{Style.RESET_ALL}, elapsed: {Fore.LIGHTYELLOW_EX}{elapsed_time:0.1f}{Style.RESET_ALL} sec, Captioned {Fore.LIGHTYELLOW_EX}{image_path}{Style.RESET_ALL}: ")
print(f"{Fore.LIGHTCYAN_EX}{caption}{Style.RESET_ALL}")
elapsed_time = time.time() - cap_start_time
logging.info(f"n:{i:05}, VRAM: {Fore.LIGHTYELLOW_EX}{vram_gb:0.1f} GB{Style.RESET_ALL}, elapsed: {Fore.LIGHTYELLOW_EX}{elapsed_time:0.1f}{Style.RESET_ALL} sec, Captioned {Fore.LIGHTYELLOW_EX}{image_path}{Style.RESET_ALL}: ")
logging.info(f"{Fore.LIGHTCYAN_EX}{caption}{Style.RESET_ALL}")
i_processed += 1

if i_processed == 0:
print(f"** No images found in {args.image_dir} with extension in {SUPPORTED_EXT} OR no images left to caption (did you use --no_overwrite?)")
logging.info(f"** No images found in {args.image_dir} with extension in {SUPPORTED_EXT} OR no images left to caption (did you use --no_overwrite?)")
exit(1)

total_elapsed_time = time.time() - total_start_time
avg_time = total_elapsed_time / i_processed
hh_mm_ss = time.strftime("%H:%M:%S", time.gmtime(total_elapsed_time))
print(f"** Done captioning {args.image_dir} with prompt '{args.prompt}', total elapsed: {hh_mm_ss} (hh_mm_ss), avg: {avg_time:0.1f} sec/image")
logging.info(f"** Done captioning {args.image_dir} with prompt '{prompt}', total elapsed: {hh_mm_ss} (hh_mm_ss), avg: {avg_time:0.1f} sec/image")


def configure_logging(args: argparse.Namespace):
    """Configure file plus console logging for the captioning run.

    Logs to caption_cog.log -- appended when --append_log is set, otherwise
    truncated -- and mirrors every record to the console. --debug raises the
    verbosity from INFO to DEBUG.
    """
    level = logging.DEBUG if args.debug else logging.INFO
    filemode = "a" if args.append_log else "w"
    logging.basicConfig(
        filename="caption_cog.log",
        level=level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        filemode=filemode,
    )

    # Console gets a bare-message format (no timestamps) for readability.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(logging.Formatter('%(message)s'))
    logging.getLogger('').addHandler(console_handler)

EXAMPLES = """ex.
Basic example:
Expand Down Expand Up @@ -189,6 +298,7 @@ def main(args):

if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.")
argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")
argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)")
Expand All @@ -206,26 +316,34 @@ def main(args):
argparser.add_argument("--force_words", type=str, default=None, help="Forces the model to include these words in the caption, use CSV format.")
argparser.add_argument("--bad_words", type=str, default=None, help="Words that will not be allowed, use CSV format.")
argparser.add_argument("--append", type=str, default=None, help="Extra string to append to all captions. ex. 'painted by John Doe'")
argparser.add_argument("--no_recurse", action="store_true", help="Do not recurse into subdirectories.")
argparser.add_argument("--prompt_plugin", type=str, default=None, help="Function name to modify prompt, edit code to add plugins.")
argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
args = argparser.parse_args()

configure_logging(args)

print(DESCRIPTION)
print(EXAMPLES)

if args.top_k is not None or args.top_p is not None or args.temp is not None:
print(f"** Sampling enabled.")
args.sampling = True
args.top_k = args.top_k or 50
args.top_p = args.top_p or 1.0
args.temp = args.temp or 1.0

if args.image_dir is None:
print(f"** {Fore.RED}Error: image_dir is required.{Style.RESET_ALL}")
logging.error(f"** {Fore.RED}Error: image_dir is required.{Style.RESET_ALL}")
exit(1)

if not os.path.exists(args.image_dir):
print(f"** {Fore.RED}Error: image_dir {args.image_dir} does not exist.{Style.RESET_ALL}")
logging.error(f"** {Fore.RED}Error: image_dir {args.image_dir} does not exist.{Style.RESET_ALL}")
exit(1)

print(f"** Running: {args.image_dir} with prompt '{args.prompt}'")
startprint = f"** Running: {args.image_dir} with prompt '{args.prompt}"
if args.starts_with is not None:
startprint += f" {args.starts_with}'"
else:
startprint += "'"
startprint += f" <caption>"
if args.append is not None:
startprint += f", and appending: {args.append}"
logging.info(startprint)

main(args)
Empty file added plugins/__init__.py
Empty file.
Loading

0 comments on commit d098223

Please sign in to comment.