|
|
|
|
|
|
|
""" |
|
JoyCaption Alpha One |
|
|
|
This module provides functionality for generating captions for images using a |
|
combination of CLIP, LLM, and custom image adapters. It supports various |
|
caption types, tones, and lengths. |
|
|
|
The main components include: |
|
- Loading and initializing models (CLIP, LLM, image adapter) |
|
- Processing images and generating captions |
|
- Command-line interface for batch processing images in a directory |
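
Typical invocation (sketch; the script filename here is an assumption):

    python joy_caption_alpha_one.py /path/to/images --caption_type descriptive --caption_tone informal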
|
""" |
|
|
|
import os |
|
import argparse |
|
import re |
|
import random |
|
from pathlib import Path |
|
from PIL import Image |
|
import pillow_jxl |
|
import torch |
|
import torchvision.transforms.functional as TVF |
|
from transformers import ( |
|
AutoModel, |
|
AutoProcessor, |
|
AutoTokenizer, |
|
AutoModelForCausalLM, |
|
PreTrainedTokenizer, |
|
PreTrainedTokenizerFast, |
|
) |
|
from torch import nn |
|
|
|
CLIP_PATH = "google/siglip-so400m-patch14-384" |
|
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B" |
|
CHECKPOINT_PATH = Path(__file__).resolve().parent / "9em124t2-499968" |
|
CAPTION_TYPE_MAP = { |
|
("descriptive", "formal", False, False): [ |
|
"Write a descriptive caption for this image in a formal tone." |
|
], |
|
("descriptive", "formal", False, True): [ |
|
"Write a descriptive caption for this image in a formal tone within " |
|
"{word_count} words." |
|
], |
|
("descriptive", "formal", True, False): [ |
|
"Write a {length} descriptive caption for this image in a formal tone." |
|
], |
|
("descriptive", "informal", False, False): [ |
|
"Write a descriptive caption for this image in a casual tone." |
|
], |
|
("descriptive", "informal", False, True): [ |
|
"Write a descriptive caption for this image in a casual tone within " |
|
"{word_count} words." |
|
], |
|
("descriptive", "informal", True, False): [ |
|
"Write a {length} descriptive caption for this image in a casual tone." |
|
], |
|
("training_prompt", "formal", False, False): [ |
|
"Write a stable diffusion prompt for this image." |
|
], |
|
("training_prompt", "formal", False, True): [ |
|
"Write a stable diffusion prompt for this image within {word_count} " |
|
"words." |
|
], |
|
("training_prompt", "formal", True, False): [ |
|
"Write a {length} stable diffusion prompt for this image." |
|
], |
|
("rng-tags", "formal", False, False): [ |
|
"Write a list of Booru tags for this image." |
|
], |
|
("rng-tags", "formal", False, True): [ |
|
"Write a list of Booru tags for this image within {word_count} words." |
|
], |
|
("rng-tags", "formal", True, False): [ |
|
"Write a {length} list of Booru tags for this image." |
|
], |
|
} |
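
# Example of how a template key is resolved and filled (see
# JoyCaptionModel._get_prompt_string below); the trailing booleans mean
# (length is a descriptor string, length is a word count):
#
#   key = ("descriptive", "formal", False, True)
#   CAPTION_TYPE_MAP[key][0].format(length=50, word_count=50)
#   # -> "Write a descriptive caption for this image in a formal tone
#   #     within 50 words."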
|
|
|
HF_TOKEN = os.environ.get("HF_TOKEN", None) |
|
|
|
class ImageAdapter(nn.Module): |
|
""" |
|
Custom image adapter module for processing CLIP vision outputs. |
|
|
|
This module adapts the output of a CLIP vision model to be compatible with |
|
a text model. It supports optional layer normalization, positional |
|
embeddings, and deep feature extraction. |
|
|
|
Args: |
|
input_features (int): Number of input features from the vision model. |
|
output_features (int): Number of output features to match the text model. |
|
ln1 (bool): Whether to use layer normalization. |
|
pos_emb (bool): Whether to use positional embeddings. |
|
num_image_tokens (int): Number of image tokens. |
|
deep_extract (bool): Whether to use deep feature extraction. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_features: int, |
|
output_features: int, |
|
ln1: bool, |
|
pos_emb: bool, |
|
num_image_tokens: int, |
|
deep_extract: bool, |
|
): |
|
super().__init__() |
|
self.deep_extract = deep_extract |
|
|
|
if self.deep_extract: |
|
input_features = input_features * 5 |
|
|
|
self.linear1 = nn.Linear(input_features, output_features) |
|
self.activation = nn.GELU() |
|
self.linear2 = nn.Linear(output_features, output_features) |
|
self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features) |
|
self.pos_emb = None if not pos_emb else nn.Parameter( |
|
torch.zeros(num_image_tokens, input_features) |
|
) |
|
|
|
self.other_tokens = nn.Embedding(3, output_features) |
|
self.other_tokens.weight.data.normal_(mean=0.0, std=0.02) |
|
|
|
def forward(self, vision_outputs: torch.Tensor): |
|
""" |
|
Forward pass of the image adapter. |
|
|
|
Args: |
|
vision_outputs (torch.Tensor): Output tensor from the CLIP vision model. |
|
|
|
Returns: |
|
torch.Tensor: Adapted image features. |
|
""" |
|
if self.deep_extract: |
|
x = torch.concat(( |
|
vision_outputs[-2], |
|
vision_outputs[3], |
|
vision_outputs[7], |
|
vision_outputs[13], |
|
vision_outputs[20], |
|
), dim=-1) |
|
assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}" |
|
assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, ( |
|
f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}" |
|
) |
|
else: |
|
x = vision_outputs[-2] |
|
|
|
x = self.ln1(x) |
|
|
|
if self.pos_emb is not None: |
|
assert x.shape[-2:] == self.pos_emb.shape, ( |
|
f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}" |
|
) |
|
x = x + self.pos_emb |
|
|
|
x = self.linear1(x) |
|
x = self.activation(x) |
|
x = self.linear2(x) |
|
|
|
other_tokens = self.other_tokens( |
|
torch.tensor([0, 1], device=self.other_tokens.weight.device).expand( |
|
x.shape[0], -1 |
|
) |
|
) |
|
assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), ( |
|
f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}" |
|
) |
|
x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1) |
|
|
|
return x |
|
|
|
def get_eot_embedding(self): |
|
""" |
|
Get the end-of-text embedding. |
|
|
|
Returns: |
|
torch.Tensor: The end-of-text embedding. |
|
""" |
|
return self.other_tokens( |
|
torch.tensor([2], device=self.other_tokens.weight.device) |
|
).squeeze(0) |
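

# Shape sketch for ImageAdapter (illustrative assumptions: 1152/729 match a
# SigLIP so400m-patch14-384 vision tower and 4096 matches Llama-3.1-8B; the
# real values come from the loaded model configs):
#
#   adapter = ImageAdapter(1152, 4096, False, False, 38, False)
#   hidden_states = tuple(torch.randn(1, 729, 1152) for _ in range(27))
#   image_tokens = adapter(hidden_states)   # (1, 729 + 2, 4096), wrapped in two learned boundary tokens
#   eot = adapter.get_eot_embedding()       # (4096,)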
|
|
|
class JoyCaptionModel: |
|
""" |
|
A class for generating captions for images using CLIP, LLM, and custom image adapters. |
|
|
|
This class encapsulates the functionality to load and initialize various models |
|
(CLIP, LLM, image adapter) and use them to process images and generate captions. |
|
It supports different caption types, tones, and lengths. |
|
|
|
Attributes: |
|
clip_model: The CLIP vision model for processing images. |
|
text_model: The language model for generating captions. |
|
image_adapter: Custom adapter for processing CLIP vision outputs. |
|
tokenizer: Tokenizer for the language model. |
|
|
|
Methods: |
|
load_models(): Load and initialize all required models. |
|
process_image(input_image, caption_type, caption_tone, caption_length): |
|
Process an input image and generate a caption based on specified parameters. |
|
""" |
|
|
|
def __init__(self): |
|
self.clip_model = None |
|
self.text_model = None |
|
self.image_adapter = None |
|
self.tokenizer = None |
|
|
|
def load_models(self): |
|
""" |
|
Load and initialize all required models (CLIP, LLM, image adapter). |
|
""" |
|
print("Loading CLIP") |
|
self.clip_model = AutoModel.from_pretrained(CLIP_PATH) |
|
self.clip_model = self.clip_model.vision_model |
|
|
|
if (CHECKPOINT_PATH / "clip_model.pt").exists(): |
|
print("Loading VLM's custom vision model") |
|
checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu') |
|
checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()} |
|
self.clip_model.load_state_dict(checkpoint) |
|
del checkpoint |
|
|
|
self.clip_model.eval() |
|
self.clip_model.requires_grad_(False) |
|
self.clip_model.to("cuda") |
|
|
|
print("Loading tokenizer") |
|
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False) |
|
        assert isinstance(
            self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
        ), f"Tokenizer is of type {type(self.tokenizer)}"
|
|
|
print("Loading LLM") |
|
if (CHECKPOINT_PATH / "text_model").exists(): |
|
print("Loading VLM's custom text model") |
|
self.text_model = AutoModelForCausalLM.from_pretrained( |
|
CHECKPOINT_PATH / "text_model", |
|
device_map=0, |
|
torch_dtype=torch.bfloat16 |
|
) |
|
else: |
|
self.text_model = AutoModelForCausalLM.from_pretrained( |
|
MODEL_PATH, |
|
device_map="auto", |
|
torch_dtype=torch.bfloat16 |
|
) |
|
|
|
self.text_model.eval() |
|
|
|
print("Loading image adapter") |
|
self.image_adapter = ImageAdapter( |
|
self.clip_model.config.hidden_size, |
|
self.text_model.config.hidden_size, |
|
False, |
|
False, |
|
38, |
|
False |
|
) |
|
self.image_adapter.load_state_dict( |
|
torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu") |
|
) |
|
self.image_adapter.eval() |
|
self.image_adapter.to("cuda") |
|
|
|
@torch.no_grad() |
|
def process_image(self, |
|
input_image: Image.Image, |
|
caption_type: str, |
|
caption_tone: str, |
|
caption_length: str | int, |
|
                      custom_prompt: str | None = None) -> str:
|
""" |
|
Process an input image and generate a caption based on specified parameters. |
|
""" |
|
torch.cuda.empty_cache() |
|
|
|
if caption_type == "custom" and custom_prompt: |
|
prompt_str = custom_prompt |
|
else: |
|
prompt_str = self._get_prompt_string(caption_type, caption_tone, caption_length) |
|
print(f"Prompt: {prompt_str}") |
|
|
|
pixel_values = self._preprocess_image(input_image) |
|
prompt = self._tokenize_prompt(prompt_str) |
|
|
|
embedded_images = self._embed_image(pixel_values) |
|
inputs_embeds, input_ids, attention_mask = self._construct_inputs(embedded_images, prompt) |
|
|
|
generate_ids = self._generate_caption(inputs_embeds, input_ids, attention_mask) |
|
caption = self._decode_caption(generate_ids, input_ids) |
|
|
|
return caption.strip() |
|
|
|
    def _get_prompt_string(self, caption_type, caption_tone, caption_length):
        """Pick and fill the prompt template for the given type, tone, and length."""
|
length = None if caption_length == "any" else caption_length |
|
|
|
if isinstance(length, str): |
|
try: |
|
length = int(length) |
|
except ValueError: |
|
pass |
|
|
|
if caption_type in {"rng-tags", "training_prompt"}: |
|
caption_tone = "formal" |
|
|
|
prompt_key = ( |
|
caption_type, |
|
caption_tone, |
|
isinstance(length, str), |
|
isinstance(length, int) |
|
) |
|
if prompt_key not in CAPTION_TYPE_MAP: |
|
raise ValueError(f"Invalid caption type: {prompt_key}") |
|
|
|
prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format( |
|
length=length, word_count=length |
|
) |
|
return prompt_str |
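
        # How caption_length maps onto the template key above:
        #   "any"  -> length None   -> key ends with (False, False)
        #   "50"   -> length 50     -> key ends with (False, True), fills {word_count}
        #   "long" -> length "long" -> key ends with (True, False), fills {length}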
|
|
|
def _preprocess_image(self, input_image): |
|
image = input_image.resize((384, 384), Image.LANCZOS) |
|
pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0 |
|
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]) |
|
pixel_values = pixel_values.to('cuda') |
|
return pixel_values |
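
        # The result is a (1, 3, 384, 384) float tensor on the GPU, scaled to
        # [0, 1] and then normalized with mean 0.5 / std 0.5 (matching SigLIP
        # preprocessing), i.e. values roughly in [-1, 1].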
|
|
|
def _tokenize_prompt(self, prompt_str): |
|
prompt = self.tokenizer.encode( |
|
prompt_str, |
|
return_tensors='pt', |
|
padding=False, |
|
truncation=False, |
|
add_special_tokens=False |
|
) |
|
return prompt |
|
|
|
    def _embed_image(self, pixel_values):
        """Run the vision model and adapt its hidden states into image tokens."""
|
with torch.amp.autocast_mode.autocast('cuda', enabled=True): |
|
vision_outputs = self.clip_model(pixel_values=pixel_values, output_hidden_states=True) |
|
image_features = vision_outputs.hidden_states |
|
embedded_images = self.image_adapter(image_features) |
|
embedded_images = embedded_images.to('cuda') |
|
return embedded_images |
|
|
|
    def _construct_inputs(self, embedded_images, prompt):
        """Splice the image tokens between the BOS embedding and the prompt embeddings."""
|
prompt_embeds = self.text_model.model.embed_tokens(prompt.to('cuda')) |
|
assert prompt_embeds.shape == (1, prompt.shape[1], self.text_model.config.hidden_size), ( |
|
f"Prompt shape is {prompt_embeds.shape}, expected " |
|
f"{(1, prompt.shape[1], self.text_model.config.hidden_size)}" |
|
) |
|
|
|
embedded_bos = self.text_model.model.embed_tokens( |
|
torch.tensor([[self.tokenizer.bos_token_id]], |
|
device=self.text_model.device, |
|
dtype=torch.int64) |
|
) |
|
|
|
eot_embed = self.image_adapter.get_eot_embedding().unsqueeze(0).to( |
|
dtype=self.text_model.dtype |
|
) |
|
|
|
inputs_embeds = torch.cat([ |
|
embedded_bos.expand(embedded_images.shape[0], -1, -1), |
|
embedded_images.to(dtype=embedded_bos.dtype), |
|
prompt_embeds.expand(embedded_images.shape[0], -1, -1), |
|
eot_embed.expand(embedded_images.shape[0], -1, -1), |
|
], dim=1) |
|
|
|
input_ids = torch.cat([ |
|
torch.tensor([[self.tokenizer.bos_token_id]], dtype=torch.long), |
|
torch.zeros((1, embedded_images.shape[1]), dtype=torch.long), |
|
prompt, |
|
torch.tensor([[self.tokenizer.eos_token_id]], dtype=torch.long), |
|
], dim=1).to('cuda') |
|
attention_mask = torch.ones_like(input_ids) |
|
|
|
return inputs_embeds, input_ids, attention_mask |
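
        # Layout of the assembled sequence:
        #   inputs_embeds: [BOS] [adapted image tokens] [prompt tokens] [EOT embedding]
        #   input_ids:     [BOS] [zero placeholders]    [prompt ids]    [EOS]
        # input_ids mirrors the embedding layout so _decode_caption can slice
        # the generated ids after the prompt.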
|
|
|
def _generate_caption(self, inputs_embeds, input_ids, attention_mask): |
|
generate_ids = self.text_model.generate( |
|
input_ids, |
|
inputs_embeds=inputs_embeds, |
|
attention_mask=attention_mask, |
|
max_new_tokens=300, |
|
do_sample=True, |
|
suppress_tokens=None |
|
) |
|
return generate_ids |
|
|
|
    def _decode_caption(self, generate_ids, input_ids):
        """Drop the prompt and any trailing end-of-text token, then decode to text."""
|
generate_ids = generate_ids[:, input_ids.shape[1]:] |
|
|
|
if (generate_ids[0][-1] == self.tokenizer.eos_token_id or |
|
generate_ids[0][-1] == self.tokenizer.convert_tokens_to_ids("<|eot_id|>")): |
|
generate_ids = generate_ids[:, :-1] |
|
|
|
caption = self.tokenizer.batch_decode( |
|
generate_ids, |
|
skip_special_tokens=False, |
|
clean_up_tokenization_spaces=False |
|
)[0] |
|
return caption |
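

# Programmatic usage sketch (assumes a CUDA device, the checkpoint directory
# next to this file, and access to the CLIP/LLM weights; "example.jpg" is a
# placeholder):
#
#   model = JoyCaptionModel()
#   model.load_models()
#   image = Image.open("example.jpg").convert("RGB")
#   print(model.process_image(image, "descriptive", "formal", "any"))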
|
|
|
|
|
def main(): |
|
"""Generate captions for images in a directory and save them as .caption files.""" |
|
parser = argparse.ArgumentParser( |
|
description="Generate captions for images in a directory and save them as .caption files." |
|
) |
|
parser.add_argument("directory", type=str, help="Target directory containing images.") |
|
parser.add_argument( |
|
"--caption_type", |
|
type=str, |
|
default="descriptive", |
|
choices=["descriptive", "training_prompt", "rng-tags", "custom"], |
|
help="Type of caption to generate." |
|
) |
|
parser.add_argument( |
|
"--caption_tone", |
|
type=str, |
|
default="formal", |
|
choices=["formal", "informal"], |
|
help="Tone of the caption." |
|
) |
|
parser.add_argument( |
|
"--caption_length", |
|
type=str, |
|
default="any", |
|
help="Length of the caption." |
|
) |
|
parser.add_argument( |
|
"--dont-strip-commas", |
|
action="store_true", |
|
help="If set, commas will not be stripped from the generated captions." |
|
) |
|
parser.add_argument( |
|
"--custom_prompt", |
|
type=str, |
|
help="Custom prompt for the captioner. Use with --caption_type custom." |
|
) |
|
parser.add_argument( |
|
'--add-commas-to-sentence-ends', |
|
action='store_true', |
|
help='Add commas after periods in sentences' |
|
) |
|
parser.add_argument( |
|
'--feed-from-tags', |
|
type=int, |
|
nargs='?', |
|
const=-1, |
|
help='Use .txt files with the same base filename as the images as input to the captioner. Optionally specify the number of tags to use.' |
|
) |
|
parser.add_argument( |
|
'--random-tags', |
|
type=int, |
|
help='Randomly select n number of tags. Only works if --feed-from-tags is enabled.' |
|
) |
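
    # Example invocations (sketch; the script filename is an assumption):
    #   python joy_caption_alpha_one.py ./images --caption_type rng-tags
    #   python joy_caption_alpha_one.py ./images --feed-from-tags 10 --random-tags 5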
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if args.random_tags is not None and args.feed_from_tags is None: |
|
parser.error("--random-tags can only be used when --feed-from-tags is enabled") |
|
|
|
|
|
    if args.caption_type == "custom" and not args.custom_prompt:
        parser.error("--custom_prompt is required when using --caption_type custom")
    elif args.caption_type != "custom" and args.custom_prompt:
        parser.error("--custom_prompt can only be used with --caption_type custom")

    joy_caption_model = JoyCaptionModel()
    joy_caption_model.load_models()
|
|
|
image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"} |
|
for image_path in Path(args.directory).rglob("*"): |
|
if image_path.suffix.lower() in image_extensions: |
|
caption_file = image_path.with_suffix('.caption') |
|
|
|
|
|
if caption_file.exists(): |
|
print(f"Skipping {image_path}: Caption file already exists.") |
|
continue |
|
|
|
input_image = Image.open(image_path).convert("RGB") |
|
|
|
|
|
if args.caption_type == "custom": |
|
caption = joy_caption_model.process_image( |
|
input_image, |
|
"custom", |
|
args.caption_tone, |
|
args.caption_length, |
|
custom_prompt=args.custom_prompt |
|
) |
|
else: |
|
|
|
if args.feed_from_tags is not None: |
|
tag_file = find_tag_file(image_path) |
|
if tag_file: |
|
with open(tag_file, 'r', encoding='utf-8') as f: |
|
tags = f.read().strip().split(',') |
|
|
|
if args.random_tags is not None: |
|
|
|
num_tags = min(args.random_tags, len(tags)) |
|
tags = random.sample(tags, num_tags) |
|
elif args.feed_from_tags > 0: |
|
|
|
tags = tags[:args.feed_from_tags] |
|
|
|
tag_string = ', '.join(tags) |
|
custom_prompt = f"Write a descriptive caption for this image in a formal tone. Use these tags as context clues to construct your caption: {tag_string}" |
|
|
|
caption = joy_caption_model.process_image( |
|
input_image, |
|
"custom", |
|
args.caption_tone, |
|
args.caption_length, |
|
custom_prompt=custom_prompt |
|
) |
|
else: |
|
caption = joy_caption_model.process_image( |
|
input_image, |
|
args.caption_type, |
|
args.caption_tone, |
|
args.caption_length |
|
) |
|
else: |
|
caption = joy_caption_model.process_image( |
|
input_image, |
|
args.caption_type, |
|
args.caption_tone, |
|
args.caption_length |
|
) |
|
|
|
|
|
            if not args.dont_strip_commas:
                # Strip commas unless they come before a digit (keeps numbers
                # like "1,024" intact).
                caption = re.sub(r',\s*([^\d])', r' \1', caption)

            if args.add_commas_to_sentence_ends:
                # Insert a comma after sentence-ending periods:
                # "A dog. It runs." -> "A dog., It runs."
                caption = re.sub(r'(\.)(\s+)([A-Z])', r'\1,\2\3', caption)
|
|
|
print(f"Caption for {image_path}:\n\n{caption}\n\n") |
|
|
|
|
|
with open(caption_file, 'w', encoding='utf-8') as f: |
|
f.write(caption) |
|
print(f"Caption saved to {caption_file}") |
|
|
|
def find_tag_file(image_path): |
|
""" |
|
Find the corresponding .txt file for the given image path. |
|
Handles cases where the image has a -(number) suffix. |
|
""" |
|
base_name = image_path.stem |
|
tag_file = image_path.with_suffix('.txt') |
|
|
|
if tag_file.exists(): |
|
return tag_file |
|
|
|
|
|
match = re.match(r'(.+)-\d+$', base_name) |
|
if match: |
|
base_name = match.group(1) |
|
tag_file = image_path.with_name(base_name).with_suffix('.txt') |
|
if tag_file.exists(): |
|
return tag_file |
|
|
|
return None |
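
    # Example: for "photo-3.jpg" this returns "photo-3.txt" if it exists,
    # otherwise it falls back to "photo.txt" (stripping the "-3" suffix),
    # and returns None if neither file is present.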
|
|
|
if __name__ == "__main__": |
|
main() |
|
|