k4d3 committed on
Commit 56166a4
1 Parent(s): b7d52fe

Code health maintenance + generate_valid_caption

Files changed (1)
  1. joy +169 -68
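
The headline addition is JoyCaptionModel.generate_valid_caption, which wraps process_image in a retry loop until the caption contains at least one word character and ends with terminal punctuation; main() now calls it in place of process_image. A minimal usage sketch follows (importing the joy script as a module, the file name, and the argument values are illustrative assumptions, not part of the commit):

# Illustrative sketch only: assumes the joy script is importable as a module
# and that its CLIP/LLM/adapter checkpoints are available locally.
from PIL import Image

from joy import JoyCaptionModel  # hypothetical import path

model = JoyCaptionModel()
model.load_models()

image = Image.open("example.png")  # hypothetical input image
caption = model.generate_valid_caption(
    image,
    caption_type="training_prompt",  # illustrative argument values
    caption_tone="formal",
    caption_length="any",
)
print(caption)  # retries internally until the caption ends with '.', '!' or '?'
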
joy CHANGED
@@ -19,6 +19,7 @@ import argparse
 import re
 import random
 from pathlib import Path
+from typing import List, Tuple, Dict
 from PIL import Image
 import pillow_jxl
 import torch
@@ -32,7 +33,6 @@ from transformers import (
 )
 from torch import nn
 from e6db_reader import TagSetNormalizer, tag_category2id, tag_rank_to_freq
-from typing import List, Tuple, Dict
 
 CLIP_PATH = "google/siglip-so400m-patch14-384"
 MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
@@ -62,7 +62,8 @@ CAPTION_TYPE_MAP = {
         "Write a stable diffusion prompt for this image."
     ],
     ("training_prompt", "formal", False, True): [
-        "Write a stable diffusion prompt for this image within {word_count} " "words."
+        "Write a stable diffusion prompt for this image within " +
+        "{word_count} words."
     ],
     ("training_prompt", "formal", True, False): [
         "Write a {length} stable diffusion prompt for this image."
@@ -90,12 +91,18 @@ class ImageAdapter(nn.Module):
     embeddings, and deep feature extraction.
 
     Args:
-        input_features (int): Number of input features from the vision model.
-        output_features (int): Number of output features to match the text model.
-        ln1 (bool): Whether to use layer normalization.
-        pos_emb (bool): Whether to use positional embeddings.
-        num_image_tokens (int): Number of image tokens.
-        deep_extract (bool): Whether to use deep feature extraction.
+        input_features (int):
+            Number of input features from the vision model.
+        output_features (int):
+            Number of output features to match the text model.
+        ln1 (bool):
+            Whether to use layer normalization.
+        pos_emb (bool):
+            Whether to use positional embeddings.
+        num_image_tokens (int):
+            Number of image tokens.
+        deep_extract (bool):
+            Whether to use deep feature extraction.
     """
 
     def __init__(
@@ -131,7 +138,8 @@ class ImageAdapter(nn.Module):
         Forward pass of the image adapter.
 
         Args:
-            vision_outputs (torch.Tensor): Output tensor from the CLIP vision model.
+            vision_outputs (torch.Tensor):
+                Output tensor from the CLIP vision model.
 
         Returns:
             torch.Tensor: Adapted image features.
@@ -148,9 +156,10 @@ class ImageAdapter(nn.Module):
                 dim=-1,
             )
             assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
+            expected_shape = vision_outputs[-2].shape[-1] * 5
             assert (
-                x.shape[-1] == vision_outputs[-2].shape[-1] * 5
-            ), f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
+                x.shape[-1] == expected_shape
+            ), f"Expected {expected_shape}, got {x.shape[-1]}"
         else:
             x = vision_outputs[-2]
 
@@ -167,9 +176,8 @@ class ImageAdapter(nn.Module):
         x = self.linear2(x)
 
         other_tokens = self.other_tokens(
-            torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(
-                x.shape[0], -1
-            )
+            torch.tensor([0, 1], device=self.other_tokens.weight.device)
+            .expand(x.shape[0], -1)
         )
         assert other_tokens.shape == (
             x.shape[0],
@@ -194,10 +202,13 @@ class ImageAdapter(nn.Module):
 
 class JoyCaptionModel:
     """
-    A class for generating captions for images using CLIP, LLM, and custom image adapters.
+    A class for generating captions for images using CLIP, LLM,
+    and custom image adapters.
 
-    This class encapsulates the functionality to load and initialize various models
-    (CLIP, LLM, image adapter) and use them to process images and generate captions.
+    This class encapsulates the functionality to load and initialize
+    various models (CLIP, LLM, image adapter) and use them to process
+    images and generate captions.
+
     It supports different caption types, tones, and lengths.
 
     Attributes:
@@ -209,7 +220,8 @@ class JoyCaptionModel:
     Methods:
         load_models(): Load and initialize all required models.
         process_image(input_image, caption_type, caption_tone, caption_length):
-            Process an input image and generate a caption based on specified parameters.
+            Process an input image and generate a caption
+            based on specified parameters.
     """
 
     def __init__(self):
@@ -232,7 +244,8 @@ class JoyCaptionModel:
             CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
         )
         checkpoint = {
-            k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()
+            k.replace("_orig_mod.module.", ""): v
+            for k, v in checkpoint.items()
        }
         self.clip_model.load_state_dict(checkpoint)
         del checkpoint
@@ -242,16 +255,20 @@ class JoyCaptionModel:
         self.clip_model.to("cuda")
 
         print("Loading tokenizer")
-        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
-        assert isinstance(self.tokenizer, PreTrainedTokenizer) or isinstance(
-            self.tokenizer, PreTrainedTokenizerFast
-        ), f"Tokenizer is of type {type(self.tokenizer)}"
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_PATH, use_fast=False
+        )
+        assert isinstance(
+            self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
+        )
 
         print("Loading LLM")
         if (CHECKPOINT_PATH / "text_model").exists():
             print("Loading VLM's custom text model")
             self.text_model = AutoModelForCausalLM.from_pretrained(
-                CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
+                CHECKPOINT_PATH / "text_model",
+                device_map=0,
+                torch_dtype=torch.bfloat16
             )
         else:
             self.text_model = AutoModelForCausalLM.from_pretrained(
@@ -270,7 +287,10 @@ class JoyCaptionModel:
             False,
         )
         self.image_adapter.load_state_dict(
-            torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")
+            torch.load(
+                CHECKPOINT_PATH / "image_adapter.pt",
+                map_location="cpu"
+            )
         )
         self.image_adapter.eval()
         self.image_adapter.to("cuda")
@@ -285,7 +305,8 @@ class JoyCaptionModel:
         custom_prompt: str | None = None,
     ) -> str:
         """
-        Process an input image and generate a caption based on specified parameters.
+        Process an input image and generate a caption based on specified
+        parameters.
         """
         torch.cuda.empty_cache()
 
@@ -305,11 +326,39 @@ class JoyCaptionModel:
             embedded_images, prompt
         )
 
-        generate_ids = self._generate_caption(inputs_embeds, input_ids, attention_mask)
+        generate_ids = self._generate_caption(inputs_embeds,
+                                              input_ids,
+                                              attention_mask)
         caption = self._decode_caption(generate_ids, input_ids)
 
         return caption.strip()
 
+    def generate_valid_caption(
+        self,
+        input_image: Image.Image,
+        caption_type: str,
+        caption_tone: str,
+        caption_length: str | int,
+        custom_prompt: str | None = None,
+    ) -> str:
+        """
+        Generate a valid caption, retrying if the caption contains only special
+        characters or does not end with a period, exclamation mark, or
+        question mark.
+        """
+        while True:
+            caption = self.process_image(
+                input_image, caption_type, caption_tone,
+                caption_length, custom_prompt
+            )
+            # This regex checks if the caption contains at least one word character
+            # and ends with a period, exclamation mark, or question mark.
+            # \w matches any word character (letters, digits, or underscore)
+            # caption[-1] checks the last character of the caption
+            if re.search(r'\w', caption) and caption[-1] in {'.', '!', '?'}:
+                return caption
+            print("Generated caption is invalid. Retrying...")
+
     def _get_prompt_string(self, caption_type, caption_tone, caption_length):
         length = None if caption_length == "any" else caption_length
 
@@ -400,10 +449,16 @@ class JoyCaptionModel:
 
         input_ids = torch.cat(
             [
-                torch.tensor([[self.tokenizer.bos_token_id]], dtype=torch.long),
-                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+                torch.tensor(
+                    [[self.tokenizer.bos_token_id]], dtype=torch.long
+                ),
+                torch.zeros(
+                    (1, embedded_images.shape[1]), dtype=torch.long
+                ),
                 prompt,
-                torch.tensor([[self.tokenizer.eos_token_id]], dtype=torch.long),
+                torch.tensor(
+                    [[self.tokenizer.eos_token_id]], dtype=torch.long
+                ),
             ],
             dim=1,
         ).to("cuda")
@@ -423,23 +478,31 @@ class JoyCaptionModel:
         return generate_ids
 
     def _decode_caption(self, generate_ids, input_ids):
-        generate_ids = generate_ids[:, input_ids.shape[1] :]
+        generate_ids = generate_ids[:, input_ids.shape[1]:]
 
-        if generate_ids[0][-1] == self.tokenizer.eos_token_id or generate_ids[0][
-            -1
-        ] == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+        if (generate_ids[0][-1] == self.tokenizer.eos_token_id or
+                generate_ids[0][-1] == self.tokenizer.convert_tokens_to_ids(
+                    "<|eot_id|>")):
             generate_ids = generate_ids[:, :-1]
 
         caption = self.tokenizer.batch_decode(
-            generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
+            generate_ids,
+            skip_special_tokens=False,
+            clean_up_tokenization_spaces=False
         )[0]
         return caption
 
 
 def main():
-    """Generate captions for images in a directory and save them as .caption files."""
+    """
+    Generate captions for images in a directory
+    and save them as .caption files.
+    """
     parser = argparse.ArgumentParser(
-        description="Generate captions for images in a directory and save them as .caption files."
+        description=(
+            "Generate captions for images in a directory and save them as "
+            ".caption files."
+        )
     )
     parser.add_argument(
         "directory", type=str, help="Target directory containing images."
@@ -459,17 +522,25 @@ def main():
         help="Tone of the caption.",
     )
     parser.add_argument(
-        "--caption_length", type=str, default="any", help="Length of the caption."
+        "--caption_length",
+        type=str,
+        default="any",
+        help="Length of the caption."
    )
     parser.add_argument(
         "--dont-strip-commas",
         action="store_true",
-        help="If set, commas will not be stripped from the generated captions.",
+        help=(
+            "If set, commas will not be stripped from the generated captions."
+        ),
     )
     parser.add_argument(
         "--custom_prompt",
         type=str,
-        help="Custom prompt for the captioner. Use with --caption_type custom.",
+        help=(
+            "Custom prompt for the captioner. "
+            "Use with --caption_type custom."
+        ),
     )
     parser.add_argument(
         "--add-commas-to-sentence-ends",
@@ -481,19 +552,28 @@ def main():
         type=int,
         nargs="?",
         const=-1,
-        help="Use .txt files with the same base filename as the images as input to the captioner. Optionally specify the number of tags to use.",
+        help=(
+            "Use .txt files with the same base filename "
+            "as the images as input to the captioner. "
+            "Optionally specify the number of tags to use."
+        ),
     )
     parser.add_argument(
         "--random-tags",
         type=int,
-        help="Randomly select n number of tags. Only works if --feed-from-tags is enabled.",
+        help=(
+            "Randomly select n number of tags. "
+            "Only works if --feed-from-tags is enabled."
+        ),
     )
 
     args = parser.parse_args()
 
     # Validate random-tags usage
     if args.random_tags is not None and args.feed_from_tags is None:
-        parser.error("--random-tags can only be used when --feed-from-tags is enabled")
+        parser.error(
+            "--random-tags can only be used when --feed-from-tags is enabled"
+        )
 
     print("Loading e621 tag data")
     tagset_normalizer = make_tagset_normalizer()
@@ -504,9 +584,13 @@ def main():
 
     # Validate custom prompt usage
     if args.caption_type == "custom" and not args.custom_prompt:
-        parser.error("--custom_prompt is required when using --caption_type custom")
+        parser.error(
+            "--custom_prompt is required when using --caption_type custom"
+        )
     elif args.caption_type != "custom" and args.custom_prompt:
-        parser.error("--custom_prompt can only be used with --caption_type custom")
+        parser.error(
+            "--custom_prompt can only be used with --caption_type custom"
+        )
 
     image_extensions = {".webp", ".png", ".jpeg", ".jpg", ".jxl"}
     for image_path in Path(args.directory).rglob("*"):
@@ -525,11 +609,13 @@ def main():
         if args.caption_type == "custom":
             custom_prompt = args.custom_prompt
         elif args.feed_from_tags is not None:
-            custom_prompt = prompt_from_tags(args, image_path, tagset_normalizer)
+            custom_prompt = prompt_from_tags(
+                args, image_path, tagset_normalizer
+            )
 
         print(f"Custom prompt: {custom_prompt}")
 
-        caption = joy_caption_model.process_image(
+        caption = joy_caption_model.generate_valid_caption(
             input_image,
             args.caption_type,
             args.caption_tone,
@@ -611,16 +697,19 @@ def make_tagset_normalizer():
     return tagset_normalizer.map_inputs(input_map, on_conflict="ignore")
 
 
-def format_nl_list(l):
-    n = len(l)
+def format_nl_list(word_list):
+    """
+    Takes a list of words and generates a natural language output.
+    """
+    n = len(word_list)
     assert n > 0
     if n == 1:
-        return l[0]
-    elif n == 2:
-        return f"{l[0]} and {l[1]}"
-    else: # n > 2
-        *head, last = l
-        return ", ".join(head) + ", and " + last
+        return word_list[0]
+    if n == 2:
+        return f"{word_list[0]} and {word_list[1]}"
+    # n > 2
+    *head, last = word_list
+    return ", ".join(head) + ", and " + last
 
 
 TAG_SPECIES = tag_category2id["species"]
@@ -631,14 +720,17 @@ TAG_META = tag_category2id["meta"]
 TAG_FREQ_THRESH = 0
 
 
-def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer):
+def prompt_from_tags(args, image_path: Path,
+                     tagset_normalizer: TagSetNormalizer):
     """
     Generates a prompt from tags associated with the given image.
 
     Args:
         args: Additional arguments for the function.
-        image_path (Path): The path to the image file.
-        tagset_normalizer (TagSetNormalizer): An instance to normalize the tag set.
+        image_path (Path):
+            The path to the image file.
+        tagset_normalizer (TagSetNormalizer):
+            An instance to normalize the tag set.
 
     Returns:
         None
@@ -655,7 +747,8 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
 
     # These lists contain tuples (freq, tag, tag_id)
     tag_by_category: Dict[int, List[Tuple[int, str, int]]] = {
-        cat: [] for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
+        cat: []
+        for cat in [TAG_ARTIST, TAG_CHARACTER, TAG_COPYRIGHT, TAG_SPECIES]
     }
     other_tags: List[Tuple[int, str, int]] = []
     implied: set = set()
@@ -664,8 +757,8 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
         # Encode the tag into a numerical id
        tag_id = encode(tag.replace(" ", "_"))
        if tag_id is None:
-            other_tags.append((0, tag, None))
-            implied.update(tagset_normalizer.implications_rej.get(tag_id, ()))
+            other_tags.append((0, tag, 0))
+            implied.update(tagset_normalizer.implications_rej.get(0, ()))
             continue
         # Get the category of the tag
         cat_id = tag_id_to_cat_id[tag_id]
@@ -677,13 +770,16 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
         freq = tag_rank_to_freq(tag_id)
         if freq < TAG_FREQ_THRESH:
             continue
-        tag_by_category.get(cat_id, other_tags).append((int(freq), tag, tag_id))
+        tag_by_category.get(cat_id, other_tags).append(
+            (int(freq), tag, tag_id)
+        )
 
     other_tags = sorted(
         (int(freq), tag, tag_id)
         for freq, tag, tag_id in other_tags
         if tag_id not in implied
     )
+
     for cat_id, cat_list in tag_by_category.items():
         tag_by_category[cat_id] = sorted(
             (int(freq), tag, tag_id)
@@ -696,8 +792,8 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
         num_tags = min(args.random_tags, len(other_tags))
         other_tags = random.sample(
             [
-                (i, tag, tag_id)
-                for i, tag, tag_id in enumerate(tags[: round(args.random_tags * 1.5)])
+                (i, tag, 0)
+                for i, tag in enumerate(tags[: round(args.random_tags * 1.5)])
             ],
             num_tags,
         )
@@ -713,25 +809,30 @@ def prompt_from_tags(args, image_path: Path, tagset_normalizer: TagSetNormalizer
         artist_txt = f"by {format_nl_list(artist_list)}"
     else:
         artist_txt = ""
+
     character_tag = tag_by_category[TAG_CHARACTER]
     if character_tag:
         tags = [tag for _, tag, _ in character_tag[:4]]
         character_txt = f"named {format_nl_list(tags)}"
     else:
         character_txt = ""
+
     species_tag = tag_by_category[TAG_SPECIES]
     if species_tag:
-        species_txt = "of a " if len(character_tag) <= 1 and len(species_tag) <= 1 else "of "
+        species_txt = (
+            "of a "
+            if len(character_tag) <= 1 and len(species_tag) <= 1
+            else "of "
+        )
         species_txt += format_nl_list([tp[1] for tp in species_tag[:4]])
     else:
         if character_tag:
             species_txt = (
-                " a character"
-                if len(character_tag) <= 1
-                else " characters"
+                " a character" if len(character_tag) <= 1 else " characters"
             )
         else:
            species_txt = ""
+
     copyright_tag = tag_by_category[TAG_COPYRIGHT]
     if copyright_tag:
         tags = [tag for _, tag, *_ in copyright_tag[:4]]