joy updates (we havent tested yet) and remove_extra_whitespace
Browse files- joy +143 -24
- remove_extra_whitespace +60 -0
joy
CHANGED
@@ -18,6 +18,7 @@ import os
|
|
18 |
import argparse
|
19 |
import re
|
20 |
import random
|
|
|
21 |
from pathlib import Path
|
22 |
from typing import List, Tuple, Dict
|
23 |
from PIL import Image
|
@@ -199,6 +200,33 @@ class ImageAdapter(nn.Module):
|
|
199 |
torch.tensor([2], device=self.other_tokens.weight.device)
|
200 |
).squeeze(0)
|
201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
class JoyCaptionModel:
|
204 |
"""
|
@@ -302,11 +330,20 @@ class JoyCaptionModel:
|
|
302 |
caption_type: str,
|
303 |
caption_tone: str,
|
304 |
caption_length: str | int,
|
305 |
-
custom_prompt: str | None = None
|
306 |
-
) -> str:
|
307 |
"""
|
308 |
-
Process
|
309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
"""
|
311 |
torch.cuda.empty_cache()
|
312 |
|
@@ -326,12 +363,18 @@ class JoyCaptionModel:
|
|
326 |
embedded_images, prompt
|
327 |
)
|
328 |
|
329 |
-
generate_ids = self._generate_caption(inputs_embeds,
|
330 |
-
input_ids,
|
331 |
-
attention_mask)
|
332 |
caption = self._decode_caption(generate_ids, input_ids)
|
333 |
|
334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
|
336 |
def generate_valid_caption(
|
337 |
self,
|
@@ -340,29 +383,62 @@ class JoyCaptionModel:
|
|
340 |
caption_tone: str,
|
341 |
caption_length: str | int,
|
342 |
custom_prompt: str | None = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
) -> str:
|
344 |
"""
|
345 |
-
Generate a valid caption, retrying if
|
346 |
-
|
347 |
-
|
348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
"""
|
350 |
while True:
|
351 |
-
caption = self.process_image(
|
352 |
input_image, caption_type, caption_tone,
|
353 |
caption_length, custom_prompt
|
354 |
)
|
355 |
-
words = re.findall(r'\b\w
|
356 |
-
word_counts = {word: words.count(word) for word in set(words)}
|
357 |
sentence_count = len(re.findall(r'[.!?]', caption))
|
358 |
-
|
359 |
-
if
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
return caption
|
365 |
-
print(f"Generated caption is invalid. Retrying...\nCaption: {caption!r}")
|
366 |
|
367 |
def _get_prompt_string(self, caption_type, caption_tone, caption_length):
|
368 |
length = None if caption_length == "any" else caption_length
|
@@ -498,6 +574,49 @@ class JoyCaptionModel:
|
|
498 |
)[0]
|
499 |
return caption
|
500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
501 |
|
502 |
def main():
|
503 |
"""
|
@@ -619,7 +738,7 @@ def main():
|
|
619 |
args, image_path, tagset_normalizer
|
620 |
)
|
621 |
|
622 |
-
print(f"
|
623 |
|
624 |
caption = joy_caption_model.generate_valid_caption(
|
625 |
input_image,
|
@@ -849,7 +968,7 @@ def prompt_from_tags(args, image_path: Path,
|
|
849 |
custom_prompt = ' '.join(s for s in [
|
850 |
"Write a descriptive caption for this image",
|
851 |
artist_txt, species_txt, character_txt, copyright_txt,
|
852 |
-
"in a formal tone. Use these tags to construct your caption:",
|
853 |
tag_string,
|
854 |
] if s)
|
855 |
return custom_prompt
|
|
|
18 |
import argparse
|
19 |
import re
|
20 |
import random
|
21 |
+
import math
|
22 |
from pathlib import Path
|
23 |
from typing import List, Tuple, Dict
|
24 |
from PIL import Image
|
|
|
200 |
torch.tensor([2], device=self.other_tokens.weight.device)
|
201 |
).squeeze(0)
|
202 |
|
203 |
+
STOP_WORDS: set[str] = {
|
204 |
+
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
|
205 |
+
"of", "with", "by", "from", "up", "down", "is", "are", "was", "were",
|
206 |
+
"be", "been", "being", "have", "has", "had", "do", "does", "did",
|
207 |
+
"will", "would", "shall", "should", "can", "could", "may", "might",
|
208 |
+
"must", "ought", "i", "you", "he", "she", "it", "we", "they", "them",
|
209 |
+
"their", "this", "that", "these", "those", "am", "is", "are", "was",
|
210 |
+
"were", "be", "been", "being", "have", "has", "had", "do", "does",
|
211 |
+
"did", "will", "would", "shall", "should", "can", "could", "may",
|
212 |
+
"might", "must", "ought", "i'm", "you're", "he's", "she's", "it's",
|
213 |
+
"we're", "they're", "i've", "you've", "we've", "they've", "i'd",
|
214 |
+
"you'd", "he'd", "she'd", "we'd", "they'd", "i'll", "you'll",
|
215 |
+
"he'll", "she'll", "we'll", "they'll", "isn't", "aren't", "wasn't",
|
216 |
+
"weren't", "hasn't", "haven't", "hadn't", "doesn't", "don't",
|
217 |
+
"didn't", "won't", "wouldn't", "shan't", "shouldn't", "can't",
|
218 |
+
"cannot", "couldn't", "mustn't", "let's", "that's", "who's",
|
219 |
+
"what's", "here's", "there's", "when's", "where's", "why's", "how's",
|
220 |
+
"a", "an", "the", "and", "but", "if", "or", "because", "as", "until",
|
221 |
+
"while", "of", "at", "by", "for", "with", "about", "against",
|
222 |
+
"between", "into", "through", "during", "before", "after", "above",
|
223 |
+
"below", "to", "from", "up", "down", "in", "out", "on", "off", "over",
|
224 |
+
"under", "again", "further", "then", "once", "here", "there", "when",
|
225 |
+
"where", "why", "how", "all", "any", "both", "each", "few", "more",
|
226 |
+
"most", "other", "some", "such", "no", "nor", "not", "only", "own",
|
227 |
+
"same", "so", "than", "too", "very"
|
228 |
+
}
|
229 |
+
|
230 |
|
231 |
class JoyCaptionModel:
|
232 |
"""
|
|
|
330 |
caption_type: str,
|
331 |
caption_tone: str,
|
332 |
caption_length: str | int,
|
333 |
+
custom_prompt: str | None = None
|
334 |
+
) -> Tuple[str, float, float]:
|
335 |
"""
|
336 |
+
Process the input image and generate a caption.
|
337 |
+
|
338 |
+
Args:
|
339 |
+
input_image (Image.Image): The input image to caption.
|
340 |
+
caption_type (str): The type of caption to generate.
|
341 |
+
caption_tone (str): The tone of the caption.
|
342 |
+
caption_length (str | int): The desired length of the caption.
|
343 |
+
custom_prompt (str | None): A custom prompt for caption generation.
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
Tuple[str, float, float]: A tuple containing the generated caption, its entropy, and its perplexity.
|
347 |
"""
|
348 |
torch.cuda.empty_cache()
|
349 |
|
|
|
363 |
embedded_images, prompt
|
364 |
)
|
365 |
|
366 |
+
generate_ids = self._generate_caption(inputs_embeds, input_ids, attention_mask)
|
|
|
|
|
367 |
caption = self._decode_caption(generate_ids, input_ids)
|
368 |
|
369 |
+
# Calculate entropy
|
370 |
+
token_ids = generate_ids[0].tolist()
|
371 |
+
entropy = self._calculate_entropy(token_ids)
|
372 |
+
|
373 |
+
# Calculate perplexity
|
374 |
+
loss = self._calculate_perplexity(generate_ids, input_ids)
|
375 |
+
perplexity = math.exp(-loss)
|
376 |
+
|
377 |
+
return caption.strip(), entropy, perplexity
|
378 |
|
379 |
def generate_valid_caption(
|
380 |
self,
|
|
|
383 |
caption_tone: str,
|
384 |
caption_length: str | int,
|
385 |
custom_prompt: str | None = None,
|
386 |
+
*,
|
387 |
+
limited_words: Dict[str, int] = {"fluffy": 2},
|
388 |
+
min_sentence_count: int = 3,
|
389 |
+
max_word_repetitions: int = 5,
|
390 |
+
min_entropy: float = 1.75,
|
391 |
+
max_perplexity: float = 100.0,
|
392 |
+
stop_words: set[str] = STOP_WORDS
|
393 |
) -> str:
|
394 |
"""
|
395 |
+
Generate a valid caption, retrying if certain conditions are not met.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
input_image (Image.Image): The input image to caption.
|
399 |
+
caption_type (str): The type of caption to generate.
|
400 |
+
caption_tone (str): The tone of the caption.
|
401 |
+
caption_length (str | int): The desired length of the caption.
|
402 |
+
custom_prompt (str | None): A custom prompt for caption generation.
|
403 |
+
limited_words (Dict[str, int]): Dictionary of words with their maximum allowed occurrences. Default is {"fluffy": 2}.
|
404 |
+
min_sentence_count (int): Minimum required number of sentences. Default is 3.
|
405 |
+
max_word_repetitions (int): Maximum allowed repetitions for words longer than 4 characters. Default is 5.
|
406 |
+
min_entropy (float): Minimum required entropy of the caption. Default is 1.75.
|
407 |
+
max_perplexity (float): Maximum allowed perplexity of the caption. Default is 100.0.
|
408 |
+
stop_words (set[str]): Set of stop words to exclude from repetition checks. Default is STOP_WORDS.
|
409 |
+
|
410 |
+
Returns:
|
411 |
+
str: A valid caption meeting all specified criteria.
|
412 |
"""
|
413 |
while True:
|
414 |
+
caption, entropy, perplexity = self.process_image(
|
415 |
input_image, caption_type, caption_tone,
|
416 |
caption_length, custom_prompt
|
417 |
)
|
418 |
+
words = re.findall(r'\b\w+\b', caption.lower())
|
419 |
+
word_counts = {word: words.count(word) for word in set(words) if word not in stop_words}
|
420 |
sentence_count = len(re.findall(r'[.!?]', caption))
|
421 |
+
|
422 |
+
if not re.search(r'\w', caption):
|
423 |
+
print(f"Retrying: Caption contains only special characters.\nCaption: {caption!r}")
|
424 |
+
elif caption[-1] not in {'.', '!', '?'}:
|
425 |
+
print(f"Retrying: Caption does not end with proper punctuation.\nCaption: {caption!r}")
|
426 |
+
elif any(caption.lower().count(word) > max_count for word, max_count in limited_words.items()):
|
427 |
+
exceeded_words = [f"{word} ({caption.lower().count(word)}/{max_count})"
|
428 |
+
for word, max_count in limited_words.items()
|
429 |
+
if caption.lower().count(word) > max_count]
|
430 |
+
print(f"Retrying: Limited words exceeded: {', '.join(exceeded_words)}.\nCaption: {caption!r}")
|
431 |
+
elif any(count > max_word_repetitions for word, count in word_counts.items() if len(word) > 4):
|
432 |
+
repeated_words = [word for word, count in word_counts.items() if count > max_word_repetitions and len(word) > 4]
|
433 |
+
print(f"Retrying: Words repeated more than {max_word_repetitions} times: {', '.join(repeated_words)}.\nCaption: {caption!r}")
|
434 |
+
elif sentence_count < min_sentence_count:
|
435 |
+
print(f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}")
|
436 |
+
elif entropy < min_entropy:
|
437 |
+
print(f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}")
|
438 |
+
elif perplexity > max_perplexity:
|
439 |
+
print(f"Retrying: High perplexity ({perplexity:.2f} > {max_perplexity}).\nCaption: {caption!r}")
|
440 |
+
else:
|
441 |
return caption
|
|
|
442 |
|
443 |
def _get_prompt_string(self, caption_type, caption_tone, caption_length):
|
444 |
length = None if caption_length == "any" else caption_length
|
|
|
574 |
)[0]
|
575 |
return caption
|
576 |
|
577 |
+
def _calculate_entropy(self, token_ids: List[int]) -> float:
|
578 |
+
"""
|
579 |
+
Calculate the entropy of a sequence of token IDs.
|
580 |
+
|
581 |
+
Args:
|
582 |
+
token_ids (List[int]): List of token IDs.
|
583 |
+
|
584 |
+
Returns:
|
585 |
+
float: Entropy of the token sequence.
|
586 |
+
"""
|
587 |
+
token_counts = {}
|
588 |
+
total_tokens = len(token_ids)
|
589 |
+
|
590 |
+
for token_id in token_ids:
|
591 |
+
token_counts[token_id] = token_counts.get(token_id, 0) + 1
|
592 |
+
|
593 |
+
entropy = 0
|
594 |
+
for count in token_counts.values():
|
595 |
+
probability = count / total_tokens
|
596 |
+
entropy -= probability * math.log2(probability)
|
597 |
+
|
598 |
+
return entropy
|
599 |
+
|
600 |
+
def _calculate_perplexity(self, generate_ids, input_ids):
|
601 |
+
"""
|
602 |
+
Calculate the perplexity of the generated caption.
|
603 |
+
|
604 |
+
Args:
|
605 |
+
generate_ids (torch.Tensor): Generated token IDs.
|
606 |
+
input_ids (torch.Tensor): Input token IDs.
|
607 |
+
|
608 |
+
Returns:
|
609 |
+
float: Perplexity of the generated caption.
|
610 |
+
"""
|
611 |
+
with torch.no_grad():
|
612 |
+
outputs = self.text_model(
|
613 |
+
input_ids=input_ids,
|
614 |
+
labels=generate_ids,
|
615 |
+
output_hidden_states=True,
|
616 |
+
)
|
617 |
+
loss = outputs.loss
|
618 |
+
return loss.item()
|
619 |
+
|
620 |
|
621 |
def main():
|
622 |
"""
|
|
|
738 |
args, image_path, tagset_normalizer
|
739 |
)
|
740 |
|
741 |
+
print(f"\nCaptioning {image_path}...\nCustom prompt: {custom_prompt}")
|
742 |
|
743 |
caption = joy_caption_model.generate_valid_caption(
|
744 |
input_image,
|
|
|
968 |
custom_prompt = ' '.join(s for s in [
|
969 |
"Write a descriptive caption for this image",
|
970 |
artist_txt, species_txt, character_txt, copyright_txt,
|
971 |
+
"in a formal tone. Limit yourself to two paragraphs, avoid repeating yourself and think before you type anything. Use these tags to construct your caption:",
|
972 |
tag_string,
|
973 |
] if s)
|
974 |
return custom_prompt
|
remove_extra_whitespace
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""
|
5 |
+
This script removes all extra spaces (more than one) and new line characters (truncating to one single character)
|
6 |
+
from all *.caption and *.txt files in a target directory recursively. If no target directory is provided as an
|
7 |
+
argument, it processes the current directory.
|
8 |
+
|
9 |
+
Usage:
|
10 |
+
python script_name.py [target_directory]
|
11 |
+
|
12 |
+
Args:
|
13 |
+
target_directory (str, optional): The path to the target directory. If not provided, the current directory is used.
|
14 |
+
"""
|
15 |
+
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
import glob
|
19 |
+
|
20 |
+
def remove_extra_spaces_and_newlines(file_path):
|
21 |
+
"""
|
22 |
+
Removes extra spaces (more than one) and new line characters from the given file.
|
23 |
+
Truncates the text to a single space or new line character without removing any text.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
file_path (str): The path to the file to be processed.
|
27 |
+
"""
|
28 |
+
with open(file_path, 'r', encoding='utf-8') as file:
|
29 |
+
content = file.read()
|
30 |
+
|
31 |
+
# Replace multiple spaces with a single space
|
32 |
+
content = ' '.join(content.split())
|
33 |
+
|
34 |
+
# Replace multiple newlines with a single newline
|
35 |
+
content = '\n'.join(line.strip() for line in content.split('\n'))
|
36 |
+
|
37 |
+
with open(file_path, 'w', encoding='utf-8') as file:
|
38 |
+
file.write(content)
|
39 |
+
|
40 |
+
def process_files_in_directory(directory):
|
41 |
+
"""
|
42 |
+
Processes all *.caption and *.txt files in the given directory recursively.
|
43 |
+
Removes extra spaces and new line characters from each file.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
directory (str): The path to the directory to be processed.
|
47 |
+
"""
|
48 |
+
for file_path in glob.glob(os.path.join(directory, '**', '*.caption'), recursive=True):
|
49 |
+
remove_extra_spaces_and_newlines(file_path)
|
50 |
+
for file_path in glob.glob(os.path.join(directory, '**', '*.txt'), recursive=True):
|
51 |
+
remove_extra_spaces_and_newlines(file_path)
|
52 |
+
|
53 |
+
if __name__ == "__main__":
|
54 |
+
if len(sys.argv) > 1:
|
55 |
+
target_directory = sys.argv[1]
|
56 |
+
else:
|
57 |
+
target_directory = os.getcwd()
|
58 |
+
|
59 |
+
process_files_in_directory(target_directory)
|
60 |
+
|