k4d3 committed
Commit ee41534
Parent: f8b95db

joy updates (we haven't tested yet) and remove_extra_whitespace

Files changed (2):
  1. joy +143 -24
  2. remove_extra_whitespace +60 -0
joy CHANGED
@@ -18,6 +18,7 @@ import os
 import argparse
 import re
 import random
+import math
 from pathlib import Path
 from typing import List, Tuple, Dict
 from PIL import Image
@@ -199,6 +200,33 @@ class ImageAdapter(nn.Module):
             torch.tensor([2], device=self.other_tokens.weight.device)
         ).squeeze(0)

+STOP_WORDS: set[str] = {
+    "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
+    "of", "with", "by", "from", "up", "down", "is", "are", "was", "were",
+    "be", "been", "being", "have", "has", "had", "do", "does", "did",
+    "will", "would", "shall", "should", "can", "could", "may", "might",
+    "must", "ought", "i", "you", "he", "she", "it", "we", "they", "them",
+    "their", "this", "that", "these", "those", "am", "is", "are", "was",
+    "were", "be", "been", "being", "have", "has", "had", "do", "does",
+    "did", "will", "would", "shall", "should", "can", "could", "may",
+    "might", "must", "ought", "i'm", "you're", "he's", "she's", "it's",
+    "we're", "they're", "i've", "you've", "we've", "they've", "i'd",
+    "you'd", "he'd", "she'd", "we'd", "they'd", "i'll", "you'll",
+    "he'll", "she'll", "we'll", "they'll", "isn't", "aren't", "wasn't",
+    "weren't", "hasn't", "haven't", "hadn't", "doesn't", "don't",
+    "didn't", "won't", "wouldn't", "shan't", "shouldn't", "can't",
+    "cannot", "couldn't", "mustn't", "let's", "that's", "who's",
+    "what's", "here's", "there's", "when's", "where's", "why's", "how's",
+    "a", "an", "the", "and", "but", "if", "or", "because", "as", "until",
+    "while", "of", "at", "by", "for", "with", "about", "against",
+    "between", "into", "through", "during", "before", "after", "above",
+    "below", "to", "from", "up", "down", "in", "out", "on", "off", "over",
+    "under", "again", "further", "then", "once", "here", "there", "when",
+    "where", "why", "how", "all", "any", "both", "each", "few", "more",
+    "most", "other", "some", "such", "no", "nor", "not", "only", "own",
+    "same", "so", "than", "too", "very"
+}
+

 class JoyCaptionModel:
     """
@@ -302,11 +330,20 @@ class JoyCaptionModel:
         caption_type: str,
         caption_tone: str,
         caption_length: str | int,
-        custom_prompt: str | None = None,
-    ) -> str:
+        custom_prompt: str | None = None
+    ) -> Tuple[str, float, float]:
         """
-        Process an input image and generate a caption based on specified
-        parameters.
+        Process the input image and generate a caption.
+
+        Args:
+            input_image (Image.Image): The input image to caption.
+            caption_type (str): The type of caption to generate.
+            caption_tone (str): The tone of the caption.
+            caption_length (str | int): The desired length of the caption.
+            custom_prompt (str | None): A custom prompt for caption generation.
+
+        Returns:
+            Tuple[str, float, float]: A tuple containing the generated caption, its entropy, and its perplexity.
         """
         torch.cuda.empty_cache()

@@ -326,12 +363,18 @@ class JoyCaptionModel:
             embedded_images, prompt
         )

-        generate_ids = self._generate_caption(inputs_embeds,
-                                              input_ids,
-                                              attention_mask)
+        generate_ids = self._generate_caption(inputs_embeds, input_ids, attention_mask)
         caption = self._decode_caption(generate_ids, input_ids)

-        return caption.strip()
+        # Calculate entropy
+        token_ids = generate_ids[0].tolist()
+        entropy = self._calculate_entropy(token_ids)
+
+        # Calculate perplexity
+        loss = self._calculate_perplexity(generate_ids, input_ids)
+        perplexity = math.exp(loss)  # perplexity = exp(mean cross-entropy loss)
+
+        return caption.strip(), entropy, perplexity

     def generate_valid_caption(
         self,
@@ -340,29 +383,62 @@
         caption_tone: str,
         caption_length: str | int,
         custom_prompt: str | None = None,
+        *,
+        limited_words: Dict[str, int] = {"fluffy": 2},
+        min_sentence_count: int = 3,
+        max_word_repetitions: int = 5,
+        min_entropy: float = 1.75,
+        max_perplexity: float = 100.0,
+        stop_words: set[str] = STOP_WORDS
     ) -> str:
         """
-        Generate a valid caption, retrying if the caption contains only special
-        characters, does not end with a period, exclamation mark, or question
-        mark, contains the word fluffy more than once, repeats any word longer
-        than 4 characters multiple times, or contains only one sentence.
+        Generate a valid caption, retrying if certain conditions are not met.
+
+        Args:
+            input_image (Image.Image): The input image to caption.
+            caption_type (str): The type of caption to generate.
+            caption_tone (str): The tone of the caption.
+            caption_length (str | int): The desired length of the caption.
+            custom_prompt (str | None): A custom prompt for caption generation.
+            limited_words (Dict[str, int]): Dictionary of words with their maximum allowed occurrences. Default is {"fluffy": 2}.
+            min_sentence_count (int): Minimum required number of sentences. Default is 3.
+            max_word_repetitions (int): Maximum allowed repetitions for words longer than 4 characters. Default is 5.
+            min_entropy (float): Minimum required entropy of the caption. Default is 1.75.
+            max_perplexity (float): Maximum allowed perplexity of the caption. Default is 100.0.
+            stop_words (set[str]): Set of stop words to exclude from repetition checks. Default is STOP_WORDS.
+
+        Returns:
+            str: A valid caption meeting all specified criteria.
         """
         while True:
-            caption = self.process_image(
+            caption, entropy, perplexity = self.process_image(
                 input_image, caption_type, caption_tone,
                 caption_length, custom_prompt
             )
-            words = re.findall(r'\b\w{5,}\b', caption.lower())
-            word_counts = {word: words.count(word) for word in set(words)}
+            words = re.findall(r'\b\w+\b', caption.lower())
+            word_counts = {word: words.count(word) for word in set(words) if word not in stop_words}
             sentence_count = len(re.findall(r'[.!?]', caption))
-
-            if (re.search(r'\w', caption) and
-                    caption[-1] in {'.', '!', '?'} and
-                    caption.lower().count('fluffy') <= 1 and
-                    all(count == 1 for count in word_counts.values()) and
-                    sentence_count > 1):
+
+            if not re.search(r'\w', caption):
+                print(f"Retrying: Caption contains only special characters.\nCaption: {caption!r}")
+            elif caption[-1] not in {'.', '!', '?'}:
+                print(f"Retrying: Caption does not end with proper punctuation.\nCaption: {caption!r}")
+            elif any(caption.lower().count(word) > max_count for word, max_count in limited_words.items()):
+                exceeded_words = [f"{word} ({caption.lower().count(word)}/{max_count})"
+                                  for word, max_count in limited_words.items()
+                                  if caption.lower().count(word) > max_count]
+                print(f"Retrying: Limited words exceeded: {', '.join(exceeded_words)}.\nCaption: {caption!r}")
+            elif any(count > max_word_repetitions for word, count in word_counts.items() if len(word) > 4):
+                repeated_words = [word for word, count in word_counts.items() if count > max_word_repetitions and len(word) > 4]
+                print(f"Retrying: Words repeated more than {max_word_repetitions} times: {', '.join(repeated_words)}.\nCaption: {caption!r}")
+            elif sentence_count < min_sentence_count:
+                print(f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}")
+            elif entropy < min_entropy:
+                print(f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}")
+            elif perplexity > max_perplexity:
+                print(f"Retrying: High perplexity ({perplexity:.2f} > {max_perplexity}).\nCaption: {caption!r}")
+            else:
                 return caption
-            print(f"Generated caption is invalid. Retrying...\nCaption: {caption!r}")

     def _get_prompt_string(self, caption_type, caption_tone, caption_length):
         length = None if caption_length == "any" else caption_length
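The retry loop above rejects a caption for one reason at a time and prints that reason before regenerating. The purely textual checks can be exercised without loading the model. The sketch below is illustrative and not part of the commit: the helper name `caption_problems`, the abbreviated stop-word set, and the sample caption are invented, and the entropy/perplexity gates are omitted because they require the language model.

```python
import re

# Abbreviated stand-in; the commit's STOP_WORDS set is much larger.
STOP_WORDS = {"the", "a", "an", "and", "is", "are", "of", "to", "in", "on"}

def caption_problems(caption: str,
                     limited_words: dict[str, int] = {"fluffy": 2},
                     min_sentence_count: int = 3,
                     max_word_repetitions: int = 5,
                     stop_words: set[str] = STOP_WORDS) -> list[str]:
    """Return the reasons a caption would be rejected (an empty list means it passes)."""
    problems = []
    words = re.findall(r'\b\w+\b', caption.lower())
    word_counts = {w: words.count(w) for w in set(words) if w not in stop_words}
    sentence_count = len(re.findall(r'[.!?]', caption))

    if not re.search(r'\w', caption):
        problems.append("contains only special characters")
    elif caption[-1] not in {'.', '!', '?'}:
        problems.append("does not end with sentence punctuation")
    for word, max_count in limited_words.items():
        if caption.lower().count(word) > max_count:
            problems.append(f"'{word}' appears more than {max_count} times")
    repeated = [w for w, c in word_counts.items() if len(w) > 4 and c > max_word_repetitions]
    if repeated:
        problems.append(f"words repeated more than {max_word_repetitions} times: {', '.join(repeated)}")
    if sentence_count < min_sentence_count:
        problems.append(f"only {sentence_count} sentence(s), need {min_sentence_count}")
    return problems

print(caption_problems("A fluffy dog. A fluffy cat. A fluffy fox."))
# ["'fluffy' appears more than 2 times"]
```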
@@ -498,6 +574,49 @@ class JoyCaptionModel:
         )[0]
         return caption

+    def _calculate_entropy(self, token_ids: List[int]) -> float:
+        """
+        Calculate the entropy of a sequence of token IDs.
+
+        Args:
+            token_ids (List[int]): List of token IDs.
+
+        Returns:
+            float: Entropy of the token sequence.
+        """
+        token_counts = {}
+        total_tokens = len(token_ids)
+
+        for token_id in token_ids:
+            token_counts[token_id] = token_counts.get(token_id, 0) + 1
+
+        entropy = 0
+        for count in token_counts.values():
+            probability = count / total_tokens
+            entropy -= probability * math.log2(probability)
+
+        return entropy
+
+    def _calculate_perplexity(self, generate_ids, input_ids):
+        """
+        Calculate the mean cross-entropy loss used to derive the perplexity of the generated caption.
+
+        Args:
+            generate_ids (torch.Tensor): Generated token IDs.
+            input_ids (torch.Tensor): Input token IDs.
+
+        Returns:
+            float: Mean cross-entropy loss; perplexity is exp() of this value.
+        """
+        with torch.no_grad():
+            outputs = self.text_model(
+                input_ids=input_ids,
+                labels=generate_ids,
+                output_hidden_states=True,
+            )
+        loss = outputs.loss
+        return loss.item()
+

 def main():
     """
@@ -619,7 +738,7 @@ def main():
             args, image_path, tagset_normalizer
         )

-        print(f"Custom prompt: {custom_prompt}")
+        print(f"\nCaptioning {image_path}...\nCustom prompt: {custom_prompt}")

         caption = joy_caption_model.generate_valid_caption(
             input_image,
@@ -849,7 +968,7 @@ def prompt_from_tags(args, image_path: Path,
     custom_prompt = ' '.join(s for s in [
         "Write a descriptive caption for this image",
         artist_txt, species_txt, character_txt, copyright_txt,
-        "in a formal tone. Use these tags to construct your caption:",
+        "in a formal tone. Limit yourself to two paragraphs, avoid repeating yourself and think before you type anything. Use these tags to construct your caption:",
         tag_string,
     ] if s)
     return custom_prompt
 
remove_extra_whitespace ADDED
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Collapse runs of spaces into a single space and runs of newline characters into a single
+newline in every *.caption and *.txt file under a target directory, searched recursively.
+If no target directory is provided as an argument, the current directory is processed.
+
+Usage:
+    python remove_extra_whitespace [target_directory]
+
+Args:
+    target_directory (str, optional): The path to the target directory. If not provided, the current directory is used.
+"""
+
+import os
+import sys
+import glob
+
+def remove_extra_spaces_and_newlines(file_path):
+    """
+    Collapse runs of spaces into a single space and runs of newlines into a single
+    newline in the given file. Only redundant whitespace is removed, never text.
+
+    Args:
+        file_path (str): The path to the file to be processed.
+    """
+    with open(file_path, 'r', encoding='utf-8') as file:
+        content = file.read()
+
+    # Collapse runs of spaces/tabs within each line to a single space (newlines survive)
+    content = '\n'.join(' '.join(line.split()) for line in content.split('\n'))
+
+    # Drop the now-empty lines so runs of newlines collapse to a single newline
+    content = '\n'.join(line for line in content.split('\n') if line)
+
+    with open(file_path, 'w', encoding='utf-8') as file:
+        file.write(content)
+
+def process_files_in_directory(directory):
+    """
+    Process all *.caption and *.txt files in the given directory recursively,
+    collapsing extra spaces and newlines in each file.
+
+    Args:
+        directory (str): The path to the directory to be processed.
+    """
+    for file_path in glob.glob(os.path.join(directory, '**', '*.caption'), recursive=True):
+        remove_extra_spaces_and_newlines(file_path)
+    for file_path in glob.glob(os.path.join(directory, '**', '*.txt'), recursive=True):
+        remove_extra_spaces_and_newlines(file_path)
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        target_directory = sys.argv[1]
+    else:
+        target_directory = os.getcwd()
+
+    process_files_in_directory(target_directory)
+
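For reference, a hypothetical before/after of the normalization the script performs; the sample string is invented and the two expressions mirror the body of `remove_extra_spaces_and_newlines`.

```python
before = "a  red   fox,\tsitting\n\n\n  in tall   grass \n"

# Collapse space/tab runs within each line, then drop the now-empty lines.
after = '\n'.join(' '.join(line.split()) for line in before.split('\n'))
after = '\n'.join(line for line in after.split('\n') if line)

print(repr(after))  # 'a red fox, sitting\nin tall grass'

# Typical invocation over a dataset directory (path shown is an example):
#   python remove_extra_whitespace /path/to/dataset
```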