Code health maintenance + generate_valid_caption
joy CHANGED
@@ -19,6 +19,7 @@ import argparse
 import re
 import random
 from pathlib import Path
+from typing import List, Tuple, Dict
 from PIL import Image
 import pillow_jxl
 import torch
@@ -32,7 +33,6 @@ from transformers import (
 )
 from torch import nn
 from e6db_reader import TagSetNormalizer, tag_category2id, tag_rank_to_freq
-from typing import List, Tuple, Dict
 
 CLIP_PATH = "google/siglip-so400m-patch14-384"
 MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
@@ -62,7 +62,8 @@ CAPTION_TYPE_MAP = {
         "Write a stable diffusion prompt for this image."
     ],
    ("training_prompt", "formal", False, True): [
-        "Write a stable diffusion prompt for this image within {word_count} words."
+        "Write a stable diffusion prompt for this image within " +
+        "{word_count} words."
     ],
     ("training_prompt", "formal", True, False): [
         "Write a {length} stable diffusion prompt for this image."
@@ -90,12 +91,18 @@ class ImageAdapter(nn.Module):
     embeddings, and deep feature extraction.
 
     Args:
-        input_features (int): Number of input features from the vision model.
-        output_features (int): Number of output features to match the text model.
-        ln1 (bool): Whether to use layer normalization.
-        pos_emb (bool): Whether to use positional embeddings.
-        num_image_tokens (int): Number of image tokens.
-        deep_extract (bool): Whether to use deep feature extraction.
+        input_features (int):
+            Number of input features from the vision model.
+        output_features (int):
+            Number of output features to match the text model.
+        ln1 (bool):
+            Whether to use layer normalization.
+        pos_emb (bool):
+            Whether to use positional embeddings.
+        num_image_tokens (int):
+            Number of image tokens.
+        deep_extract (bool):
+            Whether to use deep feature extraction.
     """
 
     def __init__(
@@ -131,7 +138,8 @@ class ImageAdapter(nn.Module):
         Forward pass of the image adapter.
 
         Args:
-            vision_outputs (torch.Tensor): Output tensor from the CLIP vision model.
+            vision_outputs (torch.Tensor):
+                Output tensor from the CLIP vision model.
 
         Returns:
             torch.Tensor: Adapted image features.
@@ -148,9 +156,10 @@ class ImageAdapter(nn.Module):
                 dim=-1,
             )
             assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
+            expected_shape = vision_outputs[-2].shape[-1] * 5
             assert (
-                x.shape[-1] == vision_outputs[-2].shape[-1] * 5
-            ), f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
+                x.shape[-1] == expected_shape
+            ), f"Expected {expected_shape}, got {x.shape[-1]}"
         else:
             x = vision_outputs[-2]
 
@@ -167,9 +176,8 @@ class ImageAdapter(nn.Module):
         x = self.linear2(x)
 
         other_tokens = self.other_tokens(
-            torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(
-                x.shape[0], -1
-            )
+            torch.tensor([0, 1], device=self.other_tokens.weight.device)
+            .expand(x.shape[0], -1)
         )
         assert other_tokens.shape == (
             x.shape[0],
@@ -194,10 +202,13 @@ class ImageAdapter(nn.Module):
 
 class JoyCaptionModel:
     """
-    A class for generating captions for images using CLIP, LLM, and custom image adapters.
+    A class for generating captions for images using CLIP, LLM,
+    and custom image adapters.
+
+    This class encapsulates the functionality to load and initialize
+    various models (CLIP, LLM, image adapter) and use them to process
+    images and generate captions.
 
-    This class encapsulates the functionality to load and initialize various models
-    (CLIP, LLM, image adapter) and use them to process images and generate captions.
     It supports different caption types, tones, and lengths.
 
     Attributes:
@@ -209,7 +220,8 @@ class JoyCaptionModel:
     Methods:
         load_models(): Load and initialize all required models.
         process_image(input_image, caption_type, caption_tone, caption_length):
-            Process an input image and generate a caption based on specified parameters.
+            Process an input image and generate a caption
+            based on specified parameters.
     """
 
     def __init__(self):
@@ -232,7 +244,8 @@ class JoyCaptionModel:
             CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
         )
         checkpoint = {
-            k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()
+            k.replace("_orig_mod.module.", ""): v
+            for k, v in checkpoint.items()
         }
         self.clip_model.load_state_dict(checkpoint)
         del checkpoint
@@ -242,16 +255,20 @@ class JoyCaptionModel:
         self.clip_model.to("cuda")
 
         print("Loading tokenizer")
-        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_PATH, use_fast=False
+        )
+        assert isinstance(
+            self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
+        )
 
         print("Loading LLM")
         if (CHECKPOINT_PATH / "text_model").exists():
             print("Loading VLM's custom text model")
             self.text_model = AutoModelForCausalLM.from_pretrained(
-                CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
+                CHECKPOINT_PATH / "text_model",
+                device_map=0,
+                torch_dtype=torch.bfloat16
             )
         else:
             self.text_model = AutoModelForCausalLM.from_pretrained(
@@ -270,7 +287,10 @@ class JoyCaptionModel:
             False,
         )
         self.image_adapter.load_state_dict(
-            torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")
+            torch.load(
+                CHECKPOINT_PATH / "image_adapter.pt",
+                map_location="cpu"
+            )
         )
         self.image_adapter.eval()
         self.image_adapter.to("cuda")
@@ -285,7 +305,8 @@ class JoyCaptionModel:
         custom_prompt: str | None = None,
     ) -> str:
         """
-        Process an input image and generate a caption based on specified parameters.
+        Process an input image and generate a caption based on specified
+        parameters.
         """
         torch.cuda.empty_cache()
 
@@ -305,11 +326,39 @@ class JoyCaptionModel:
             embedded_images, prompt
         )
 
-        generate_ids = self._generate_caption(inputs_embeds, input_ids, attention_mask)
+        generate_ids = self._generate_caption(inputs_embeds,
+                                              input_ids,
+                                              attention_mask)
        caption = self._decode_caption(generate_ids, input_ids)
 
         return caption.strip()
 
+    def generate_valid_caption(
+        self,
+        input_image: Image.Image,
+        caption_type: str,
+        caption_tone: str,
+        caption_length: str | int,
+        custom_prompt: str | None = None,
+    ) -> str:
+        """
+        Generate a valid caption, retrying if the caption contains only special
+        characters or does not end with a period, exclamation mark, or
+        question mark.
+        """
+        while True:
+            caption = self.process_image(
+                input_image, caption_type, caption_tone,
+                caption_length, custom_prompt
+            )
+            # This regex checks if the caption contains at least one word character
+            # and ends with a period, exclamation mark, or question mark.
+            # \w matches any word character (letters, digits, or underscore)
+            # caption[-1] checks the last character of the caption
+            if re.search(r'\w', caption) and caption[-1] in {'.', '!', '?'}:
+                return caption
+            print("Generated caption is invalid. Retrying...")
+
     def _get_prompt_string(self, caption_type, caption_tone, caption_length):
         length = None if caption_length == "any" else caption_length
 
@@ -400,10 +449,16 @@ class JoyCaptionModel:
 
         input_ids = torch.cat(
             [
-                torch.tensor([[self.tokenizer.bos_token_id]], dtype=torch.long),
-                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+                torch.tensor(
+                    [[self.tokenizer.bos_token_id]], dtype=torch.long
+                ),
+                torch.zeros(
+                    (1, embedded_images.shape[1]), dtype=torch.long
+                ),
                 prompt,
-                torch.tensor([[self.tokenizer.eos_token_id]], dtype=torch.long),
+                torch.tensor(
+                    [[self.tokenizer.eos_token_id]], dtype=torch.long
+                ),
             ],
             dim=1,
         ).to("cuda")
@@ -423,23 +478,31 @@ class JoyCaptionModel:
         return generate_ids
 
     def _decode_caption(self, generate_ids, input_ids):
-        generate_ids = generate_ids[:, input_ids.shape[1] :]
+        generate_ids = generate_ids[:, input_ids.shape[1]:]
 
-        if generate_ids[0][-1] == self.tokenizer.eos_token_id or generate_ids[0][
-            -1
-        ] == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+        if (generate_ids[0][-1] == self.tokenizer.eos_token_id or
+                generate_ids[0][-1] == self.tokenizer.convert_tokens_to_ids(
+                    "<|eot_id|>")):
             generate_ids = generate_ids[:, :-1]
 
         caption = self.tokenizer.batch_decode(
-            generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
+            generate_ids,
+            skip_special_tokens=False,
+            clean_up_tokenization_spaces=False
        )[0]
         return caption
 
 
 def main():
-    """Generate captions for images in a directory and save them as .caption files."""
+    """
+    Generate captions for images in a directory
+    and save them as .caption files.
+    """
     parser = argparse.ArgumentParser(
-        description="Generate captions for images in a directory and save them as .caption files."
+        description=(
+            "Generate captions for images in a directory and save them as "
+            ".caption files."
+        )
     )
     parser.add_argument(
         "directory", type=str, help="Target directory containing images."
@@ -459,17 +522,25 @@ def main():
         help="Tone of the caption.",
     )
     parser.add_argument(
-        "--caption_length", type=str, default="any", help="Length of the caption."
+        "--caption_length",
+        type=str,
+        default="any",
+        help="Length of the caption."
     )
     parser.add_argument(
         "--dont-strip-commas",
         action="store_true",
-        help="If set, commas will not be stripped from the generated captions.",
+        help=(
+            "If set, commas will not be stripped from the generated captions."
+        ),
     )
     parser.add_argument(
         "--custom_prompt",
         type=str,
-        help="Custom prompt for the captioner. Use with --caption_type custom.",
+        help=(
+            "Custom prompt for the captioner. "
+            "Use with --caption_type custom."
+        ),
     )
     parser.add_argument(
         "--add-commas-to-sentence-ends",
@@ -481,19 +552,28 @@ def main():
         type=int,
         nargs="?",
         const=-1,
-        help="Use .txt files with the same base filename as the images as input to the captioner. Optionally specify the number of tags to use.",
+        help=(
+            "Use .txt files with the same base filename "
+            "as the images as input to the captioner. "
+            "Optionally specify the number of tags to use."
+        ),
     )
     parser.add_argument(
         "--random-tags",
         type=int,
-        help="Randomly select n number of tags. Only works if --feed-from-tags is enabled.",
+        help=(
+            "Randomly select n number of tags. "
+            "Only works if --feed-from-tags is enabled."
+        ),
     )
 
     args = parser.parse_args()
 
     # Validate random-tags usage
     if args.random_tags is not None and args.feed_from_tags is None:
-        parser.error("--random-tags can only be used when --feed-from-tags is enabled")
+        parser.error(
+            "--random-tags can only be used when --feed-from-tags is enabled"
+        )
 
     print("Loading e621 tag data")
     tagset_normalizer = make_tagset_normalizer()
@@ -504,9 +584,13 @@ def main():
 
     # Validate custom prompt usage
     if args.caption_type == "custom" and not args.custom_prompt:
-        parser.error("--custom_prompt is required when using --caption_type custom")
+        parser.error(
+            "--custom_prompt is required when using --caption_type custom"
+        )
     elif args.caption_type != "custom" and args.custom_prompt:
-        parser.error("--custom_prompt can only be used with --caption_type custom")
+        parser.error(
+            "--custom_prompt can only be used with --caption_type custom"
+        )
 
     image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"}
     for image_path in Path(args.directory).rglob("*"):
@@ -525,11 +609,13 @@ def main():
         if args.caption_type == "custom":
             custom_prompt = args.custom_prompt
         elif args.feed_from_tags is not None:
-            custom_prompt = prompt_from_tags(args, image_path, tagset_normalizer)
+            custom_prompt = prompt_from_tags(
+                args, image_path, tagset_normalizer
+            )
 
         print(f"Custom prompt: {custom_prompt}")
 
-        caption = joy_caption_model.process_image(
+        caption = joy_caption_model.generate_valid_caption(
            input_image,
             args.caption_type,
             args.caption_tone,
@@ -611,16 +697,19 @@ def make_tagset_normalizer():
     return tagset_normalizer.map_inputs(input_map, on_conflict="ignore")
 
 
-def format_nl_list(
+def format_nl_list(word_list):
+    """
+    Takes a list of words and generates a natural language output.
+    """
+    n = len(word_list)
     assert n > 0
     if n == 1:
-        return
+        return word_list[0]
+    if n == 2:
+        return f"{word_list[0]} and {word_list[1]}"
+    # n > 2
+    *head, last = word_list
+    return ", ".join(head) + ", and " + last
 
 
 TAG_SPECIES = tag_category2id["species"]
@@ -631,14 +720,17 @@ TAG_META = tag_category2id["meta"]
 TAG_FREQ_THRESH = 0
 
 
-def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer):
+def prompt_from_tags(args, image_path: Path,
+                     tagset_normalizer: TagSetNormalizer):
     """
     Generates a prompt from tags associated with the given image.
 
     Args:
         args: Additional arguments for the function.
-        image_path (Path): The path to the image file.
-        tagset_normalizer (TagSetNormalizer): An instance to normalize the tag set.
+        image_path (Path):
+            The path to the image file.
+        tagset_normalizer (TagSetNormalizer):
+            An instance to normalize the tag set.
 
     Returns:
         None
@@ -655,7 +747,8 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
 
     # These lists contain tuples (freq, tag, tag_id)
     tag_by_category: Dict[int, List[Tuple[int, str, int]]] = {
-        cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
+        cat: []
+        for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
     }
     other_tags: List[Tuple[int, str, int]] = []
     implied: set = set()
@@ -664,8 +757,8 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
         # Encode the tag into a numerical id
         tag_id = encode(tag.replace(" ", "_"))
         if tag_id is None:
-            other_tags.append((0, tag,
-            implied.update(tagset_normalizer.implications_rej.get(
+            other_tags.append((0, tag, 0))
+            implied.update(tagset_normalizer.implications_rej.get(0, ()))
             continue
         # Get the category of the tag
         cat_id = tag_id_to_cat_id[tag_id]
@@ -677,13 +770,16 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
         freq = tag_rank_to_freq(tag_id)
         if freq < TAG_FREQ_THRESH:
             continue
-        tag_by_category.get(cat_id, other_tags).append((int(freq), tag, tag_id))
+        tag_by_category.get(cat_id, other_tags).append(
+            (int(freq), tag, tag_id)
+        )
 
     other_tags = sorted(
         (int(freq), tag, tag_id)
         for freq, tag, tag_id in other_tags
         if tag_id not in implied
     )
+
     for cat_id, cat_list in tag_by_category.items():
         tag_by_category[cat_id] = sorted(
             (int(freq), tag, tag_id)
@@ -696,8 +792,8 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
         num_tags = min(args.random_tags, len(other_tags))
         other_tags = random.sample(
             [
-                (i, tag,
-                for i, tag
+                (i, tag, 0)
+                for i, tag in enumerate(tags[: round(args.random_tags * 1.5)])
             ],
             num_tags,
         )
@@ -713,25 +809,30 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
         artist_txt = f"by {format_nl_list(artist_list)}"
     else:
         artist_txt = ""
+
     character_tag = tag_by_category[TAG_CHARACTER]
     if character_tag:
         tags = [tag for _, tag, _ in character_tag[:4]]
         character_txt = f"named {format_nl_list(tags)}"
     else:
         character_txt = ""
+
     species_tag = tag_by_category[TAG_SPECIES]
     if species_tag:
-        species_txt =
+        species_txt = (
+            "of a "
+            if len(character_tag) <= 1 and len(species_tag) <= 1
+            else "of "
+        )
         species_txt += format_nl_list([tp[1] for tp in species_tag[:4]])
     else:
         if character_tag:
             species_txt = (
-                " a character"
-                if len(character_tag) <= 1
-                else " characters"
+                " a character" if len(character_tag) <= 1 else " characters"
            )
         else:
             species_txt = ""
+
     copyright_tag = tag_by_category[TAG_COPYRIGHT]
     if copyright_tag:
         tags = [tag for _, tag, *_ in copyright_tag[:4]]