import re
from typing import Any, Union, List

import numpy as np
from PIL import Image
from transformers import BaseImageProcessor, LlavaProcessor, PreTrainedTokenizer
from transformers.models.llava.processing_llava import LlavaProcessorKwargs
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput


class Maira2Processor(LlavaProcessor):
    """
    Constructs a Maira2 processor similar to LlavaProcessor but with additional arguments and functions to support
    multi-image grounded and non-grounded radiology report generation.

    In addition to the arguments of LlavaProcessor, Maira2Processor has the following extra arguments:

    Args:
        phrase_start_token (`str`, *optional*, defaults to `"<obj>"`):
            Special token used to denote the start of a grounded phrase (with or without box).
        phrase_end_token (`str`, *optional*, defaults to `"</obj>"`):
            Special token used to denote the end of a grounded phrase.
        box_start_token (`str`, *optional*, defaults to `"<box>"`):
            Special token used to denote the start of a bounding box.
        box_end_token (`str`, *optional*, defaults to `"</box>"`):
            Special token used to denote the end of a bounding box.
        num_box_coord_bins (`int`, *optional*, defaults to `100`):
            Number of bins used to represent the bounding box coordinates.
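
    Example (illustrative sketch; the checkpoint name below is a placeholder for wherever the MAIRA-2 weights and
    processor configuration are published, `trust_remote_code` may or may not be needed depending on that source, and
    the printed tokens assume the default configuration):

    ```python
    >>> from transformers import AutoProcessor
    >>> processor = AutoProcessor.from_pretrained("microsoft/maira-2", trust_remote_code=True)
    >>> (processor.phrase_start_token, processor.box_start_token)
    >>> # ('<obj>', '<box>')
    ```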
    """

    valid_kwargs = [
        "chat_template",
        "patch_size",
        "vision_feature_select_strategy",
        "image_token",
        "phrase_start_token",
        "phrase_end_token",
        "box_start_token",
        "box_end_token",
        "num_box_coord_bins",
    ]

    def __init__(
        self,
        image_processor: BaseImageProcessor = None,
        tokenizer: PreTrainedTokenizer = None,
        patch_size=None,
        vision_feature_select_strategy=None,
        chat_template=None,
        image_token: str = "<image>",
        phrase_start_token: str = "<obj>",
        phrase_end_token: str = "</obj>",
        box_start_token: str = "<box>",
        box_end_token: str = "</box>",
        num_box_coord_bins: int = 100,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            patch_size=patch_size,
            vision_feature_select_strategy=vision_feature_select_strategy,
            chat_template=chat_template,
            image_token=image_token,
            **kwargs,
        )
        self.phrase_start_token = phrase_start_token
        self.phrase_end_token = phrase_end_token
        self.box_start_token = box_start_token
        self.box_end_token = box_end_token
        self.num_box_coord_bins = num_box_coord_bins

    @staticmethod
    def _normalize_image(image: Image.Image) -> Image.Image:
        """
        This function normalizes the input image to have pixel values spanning the full range [0, 255].

        Args:
            image (Image.Image):
                The input image to be normalized.

        Returns:
            Image.Image: The normalized image in grayscale.
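
        Example (worked illustration of the min-max scaling; the pixel values are chosen only for demonstration):

        ```python
        >>> img = Image.fromarray(np.array([[50, 100], [150, 150]], dtype=np.uint8))
        >>> np.array(Maira2Processor._normalize_image(img))
        >>> # array([[  0, 127],
        >>> #        [255, 255]], dtype=uint8)
        ```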
        """
        image_np = np.array(image.convert("L"))
        image_np = image_np.astype(float)
        # Min-max scale to [0, 1], then rescale to [0, 255] and cast back to uint8.
        image_np -= image_np.min()
        image_np /= image_np.max()
        image_np *= 255
        image_np = image_np.astype(np.uint8)
        return Image.fromarray(image_np).convert("L")

    def _normalize_and_stack_images(
        self,
        current_frontal: Image.Image,
        current_lateral: Image.Image,
        prior_frontal: Image.Image,
    ):
        """
        This function normalizes the input images and stacks them together. The images are stacked in the order of
        current_frontal, current_lateral, and prior_frontal. The order of images is important, since it must match the
        order of the images in the prompt, which is frontal, then lateral, then prior.

        Args:
            current_frontal (Image.Image):
                The current frontal image.
            current_lateral (Image.Image | None):
                The current lateral image.
            prior_frontal (Image.Image | None):
                The prior frontal image.

        Returns:
            list[Image.Image]: The normalized images stacked together.
        """
        images = [self._normalize_image(current_frontal)]
        if current_lateral is not None:
            images.append(self._normalize_image(current_lateral))
        if prior_frontal is not None:
            images.append(self._normalize_image(prior_frontal))
        return images

    @staticmethod
    def _get_section_text_or_missing_text(section: str) -> str:
        """
        This function returns the input section text if it is not None and not empty, otherwise it returns a missing
        section text "N/A".

        Args:
            section (str | None):
                The input section text.

        Returns:
            str: The section text if it is not None and not empty, otherwise "N/A".
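
        Example (trivial illustration of the fallback behaviour; the section text is only indicative):

        ```python
        >>> _get_section_text_or_missing_text("Shortness of breath.")
        >>> # "Shortness of breath."
        >>> _get_section_text_or_missing_text(None)
        >>> # "N/A"
        ```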
        """
        missing_section_text = "N/A"
        if not isinstance(section, str) or len(section) == 0:
            return missing_section_text
        return section

    @staticmethod
    def _construct_image_chat_messages_for_reporting(has_prior: bool, has_lateral: bool):
        """
        This function constructs user chat messages based on the presence of the prior and lateral images.

        Args:
            has_prior (bool):
                A boolean indicating whether the prior image is present.
            has_lateral (bool):
                A boolean indicating whether the lateral image is present.

        Returns:
            list[SingleChatMessageType]: The image prompt messages in the form of a list of dictionaries.

        Example:

        ```python
        >>> _construct_image_chat_messages_for_reporting(has_prior=True, has_lateral=True)
        >>> # [
        >>> # {"index": None, "text": "Given the current frontal image", "type": "text"},
        >>> # {"index": 0, "text": None, "type": "image"},
        >>> # {"index": None, "text": " the current lateral image", "type": "text"},
        >>> # {"index": 1, "text": None, "type": "image"},
        >>> # {"index": None, "text": " and the prior frontal image", "type": "text"},
        >>> # {"index": 2, "text": None, "type": "image"},
        >>> # ]
        ```
        """

        def _add_single_image_to_chat_messages(prompt_text: str, image_index: int) -> None:
            image_prompt.extend(
                [
                    {"index": None, "text": prompt_text, "type": "text"},
                    {"index": image_index, "text": None, "type": "image"},
                ]
            )

        image_prompt = []
        image_index = 0
        if not has_prior and not has_lateral:
            _add_single_image_to_chat_messages("Given the current frontal image only", image_index)
        else:
            _add_single_image_to_chat_messages("Given the current frontal image", image_index)
            image_index += 1
            if has_prior:
                if has_lateral:
                    _add_single_image_to_chat_messages(" the current lateral image", image_index)
                    image_index += 1
                _add_single_image_to_chat_messages(" and the prior frontal image", image_index)
            else:
                if has_lateral:
                    _add_single_image_to_chat_messages(" and the current lateral image", image_index)
        return image_prompt

    def _construct_chat_messages_reporting(
        self,
        has_prior: bool,
        has_lateral: bool,
        indication: str,
        technique: str,
        comparison: str,
        prior_report: str,
        get_grounding: bool = False,
        assistant_text: str = None,
    ):
        """
        This function constructs the chat messages for reporting used in the grounded and non-grounded reporting tasks.

        Args:
            has_prior (bool):
                A boolean indicating whether the prior image is present.
            has_lateral (bool):
                A boolean indicating whether the lateral image is present.
            indication (str | None):
                The indication section text.
            technique (str | None):
                The technique section text.
            comparison (str | None):
                The comparison section text.
            prior_report (str | None):
                The prior report section text.
            get_grounding (bool):
                A boolean indicating whether to get the grounding information.
            assistant_text (str | None):
                The assistant text (can be set to None for ordinary inference).

        Returns:
            ChatMessageListType: The chat messages for reporting in the form of a list of dictionaries.

        Example:

        ```python
        >>> _construct_chat_messages_reporting(
        >>>     has_prior=True,
        >>>     has_lateral=True,
        >>>     indication="indication text from report goes here",
        >>>     technique="technique text from report goes here",
        >>>     comparison="comparison text from report goes here",
        >>>     prior_report="prior reporting text goes here",
        >>>     get_grounding=False,
        >>>     assistant_text=None,
        >>> )
        >>> # [
        >>> #     {
        >>> #         "content": [
        >>> #             {"index": None, "text": "Given the current frontal image", "type": "text"},
        >>> #             {"index": 0, "text": None, "type": "image"},
        >>> #             {"index": None, "text": " the current lateral image", "type": "text"},
        >>> #             {"index": 1, "text": None, "type": "image"},
        >>> #             {"index": None, "text": " and the prior frontal image", "type": "text"},
        >>> #             {"index": 2, "text": None, "type": "image"},
        >>> #             {"index": None, "text": " PRIOR_REPORT: prior reporting text goes here", "type": "text"},
        >>> #             {"index": None, "text": " Provide a description of the findings in the radiology study in comparison to the "
        >>> #             "prior frontal image. INDICATION: indication text from report goes here TECHNIQUE: technique text from report "
        >>> #             "goes here COMPARISON: comparison text from report goes here", "type": "text"},
        >>> #         ],
        >>> #         "role": "user",
        >>> #     },
        >>> # ]
        ```
        """
        indication = self._get_section_text_or_missing_text(indication)
        technique = self._get_section_text_or_missing_text(technique)
        comparison = self._get_section_text_or_missing_text(comparison)
        prior_report = self._get_section_text_or_missing_text(prior_report)

        prompt = self._construct_image_chat_messages_for_reporting(has_prior=has_prior, has_lateral=has_lateral)

        if has_prior:
            prompt.append({"index": None, "text": f" PRIOR_REPORT: {prior_report}", "type": "text"})

        if get_grounding:
            prompt.append(
                {
                    "index": None,
                    "text": " Provide a description of the findings in the radiology study in comparison to the "
                    "prior frontal image. Each finding should be described as a self-contained plain-text sentence."
                    " If the finding is groundable, locate the finding in the current frontal chest X-ray image, "
                    "with bounding boxes indicating all locations where it can be seen in the current frontal "
                    "image. Otherwise, generate just the ungrounded finding without bounding boxes. INDICATION: "
                    f"{indication} TECHNIQUE: {technique} COMPARISON: {comparison}",
                    "type": "text",
                }
            )
        else:
            prompt.append(
                {
                    "index": None,
                    "text": " Provide a description of the findings in the radiology study in comparison to the "
                    f"prior frontal image. INDICATION: {indication} TECHNIQUE: {technique} COMPARISON: "
                    f"{comparison}",
                    "type": "text",
                }
            )
        messages = [{"content": prompt, "role": "user"}]
        if assistant_text is not None:
            messages.append({"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"})
        return messages

    def _construct_chat_messages_phrase_grounding(self, phrase: str, assistant_text: str = None):
        """
        This function constructs the chat messages for phrase grounding used in the phrase grounding task.

        Args:
            phrase (str):
                The phrase to be grounded.
            assistant_text (str | None):
                The assistant text (can be set to None for ordinary inference).

        Returns:
            ChatMessageListType: The chat messages for phrase grounding in the form of a list of dictionaries.
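
        Example (illustrative; the content mirrors what this method constructs for the given phrase, which is itself
        only indicative):

        ```python
        >>> _construct_chat_messages_phrase_grounding(phrase="Pleural effusion.")
        >>> # [
        >>> #     {
        >>> #         "content": [
        >>> #             {"index": None, "text": "Given the current frontal image", "type": "text"},
        >>> #             {"index": 0, "text": None, "type": "image"},
        >>> #             {"index": None, "text": " Repeat the following finding as a grounded phrase with bounding boxes "
        >>> #             "indicating all locations where it can be seen in the given chest X-ray image. Finding: Pleural effusion.", "type": "text"},
        >>> #         ],
        >>> #         "role": "user",
        >>> #     },
        >>> # ]
        ```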
        """
        prompt = [
            {"index": None, "text": "Given the current frontal image", "type": "text"},
            {"index": 0, "text": None, "type": "image"},
            {
                "index": None,
                "text": " Repeat the following finding as a grounded phrase with bounding boxes indicating all "
                f"locations where it can be seen in the given chest X-ray image. Finding: {phrase}",
                "type": "text",
            },
        ]
        messages = [{"content": prompt, "role": "user"}]
        if assistant_text is not None:
            messages.append({"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"})
        return messages

    def format_reporting_input(
        self,
        current_frontal: Image.Image,
        current_lateral: Image.Image,
        prior_frontal: Image.Image,
        indication: str,
        technique: str,
        comparison: str,
        prior_report: str,
        get_grounding: bool = False,
        assistant_text: str = None,
    ):
        """
        This function formats the reporting prompt for the grounded and non-grounded reporting tasks from the given
        input images and text sections. The images are normalized and stacked together in the right order.

        Args:
            current_frontal (Image.Image):
                The current frontal image.
            current_lateral (Image.Image | None):
                The current lateral image.
            prior_frontal (Image.Image | None):
                The prior frontal image.
            indication (str | None):
                The indication section text.
            technique (str | None):
                The technique section text.
            comparison (str | None):
                The comparison section text.
            prior_report (str | None):
                The prior report section text.
            get_grounding (bool):
                A boolean indicating whether to construct the prompt for grounded or non-grounded reporting.
            assistant_text (str | None):
                The assistant text (can be set to None for ordinary inference).

        Returns:
            tuple[str, list[Image.Image]]: The formatted prompt text and the normalized images stacked in the right order.
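
        Example (illustrative sketch; `processor` is assumed to be an instantiated `Maira2Processor`, and `frontal`,
        `lateral` and `prior` to be `PIL.Image.Image` objects loaded by the caller; the section texts are only
        indicative):

        ```python
        >>> text, images = processor.format_reporting_input(
        >>>     current_frontal=frontal,
        >>>     current_lateral=lateral,
        >>>     prior_frontal=prior,
        >>>     indication="Shortness of breath.",
        >>>     technique="PA and lateral views of the chest.",
        >>>     comparison="None.",
        >>>     prior_report="No prior report available.",
        >>>     get_grounding=False,
        >>> )
        >>> len(images)
        >>> # 3
        ```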
        """
        images = self._normalize_and_stack_images(
            current_frontal=current_frontal,
            current_lateral=current_lateral,
            prior_frontal=prior_frontal,
        )
        messages = self._construct_chat_messages_reporting(
            has_prior=prior_frontal is not None,
            has_lateral=current_lateral is not None,
            indication=indication,
            technique=technique,
            comparison=comparison,
            prior_report=prior_report,
            get_grounding=get_grounding,
            assistant_text=assistant_text,
        )
        add_generation_prompt = assistant_text is None
        text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=add_generation_prompt, tokenize=False)
        return text, images

    def format_phrase_grounding_input(
        self,
        frontal_image: Image.Image,
        phrase: str,
        assistant_text: str = None,
    ):
        """
        This function formats the phrase grounding prompt for the phrase grounding task from the given input image
        and phrase.

        Args:
            frontal_image (Image.Image):
                The frontal image.
            phrase (str):
                The phrase to be grounded.
            assistant_text (str | None):
                The assistant text (can be set to None for ordinary inference).

        Returns:
            tuple[str, list[Image.Image]]: The formatted phrase grounding prompt text and the normalized image.
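
        Example (illustrative sketch; `processor` and the `frontal` image are assumed to be prepared by the caller,
        and the phrase is only indicative):

        ```python
        >>> text, images = processor.format_phrase_grounding_input(
        >>>     frontal_image=frontal,
        >>>     phrase="Pleural effusion.",
        >>> )
        >>> len(images)
        >>> # 1
        ```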
        """
        images = self._normalize_and_stack_images(
            current_frontal=frontal_image,
            current_lateral=None,
            prior_frontal=None,
        )
        messages = self._construct_chat_messages_phrase_grounding(phrase, assistant_text=assistant_text)
        add_generation_prompt = assistant_text is None
        text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=add_generation_prompt, tokenize=False)
        return text, images

    def format_and_preprocess_reporting_input(
        self,
        current_frontal: Image.Image,
        current_lateral: Image.Image,
        prior_frontal: Image.Image,
        indication: str,
        technique: str,
        comparison: str,
        prior_report: str,
        get_grounding: bool = False,
        assistant_text: str = None,
        **kwargs: Any,
    ) -> BatchFeature:
        """
        This function formats and then preprocesses the input for the grounded and non-grounded reporting tasks from
        the given input images and text sections and returns the batch feature for the model. It calls
        format_reporting_input internally to format the input prompt and stack the images together in the right order.

        Args:
            current_frontal (Image.Image):
                The current frontal image.
            current_lateral (Image.Image | None):
                The current lateral image.
            prior_frontal (Image.Image | None):
                The prior frontal image.
            indication (str | None):
                The indication section text.
            technique (str | None):
                The technique section text.
            comparison (str | None):
                The comparison section text.
            prior_report (str | None):
                The prior report section text.
            get_grounding (bool):
                A boolean indicating whether to preprocess the input for grounded or non-grounded reporting.
            assistant_text (str | None):
                The assistant text (can be set to None for ordinary inference).

        Returns:
            BatchFeature: The batch feature for the model, ready to be passed to the model.
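
        Example (illustrative sketch; `processor` and `model` are assumed to be the already-loaded MAIRA-2 processor
        and model, `frontal`, `lateral` and `prior` to be `PIL.Image.Image` objects, and `max_new_tokens` is only
        indicative):

        ```python
        >>> inputs = processor.format_and_preprocess_reporting_input(
        >>>     current_frontal=frontal,
        >>>     current_lateral=lateral,
        >>>     prior_frontal=prior,
        >>>     indication="Shortness of breath.",
        >>>     technique="PA and lateral views of the chest.",
        >>>     comparison="None.",
        >>>     prior_report=None,
        >>>     get_grounding=False,
        >>>     return_tensors="pt",
        >>> )
        >>> output = model.generate(**inputs, max_new_tokens=300)
        ```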
        """
        text, images = self.format_reporting_input(
            current_frontal=current_frontal,
            current_lateral=current_lateral,
            prior_frontal=prior_frontal,
            indication=indication,
            technique=technique,
            comparison=comparison,
            prior_report=prior_report,
            get_grounding=get_grounding,
            assistant_text=assistant_text,
        )
        return self(text=text, images=images, **kwargs)

    def format_and_preprocess_phrase_grounding_input(
        self,
        frontal_image: Image.Image,
        phrase: str,
        assistant_text: str = None,
        **kwargs: Any,
    ) -> BatchFeature:
        """
        This function formats and then processes the input for the phrase grounding task from the given input image and
        phrase and returns the batch feature for the model. It calls format_phrase_grounding_input internally to format
        the input prompt and normalize the image.

        Args:
            frontal_image (Image.Image):
                The frontal image.
            phrase (str):
                The phrase to be grounded.
            assistant_text (str | None):
                The assistant text (can be set to None for ordinary inference).

        Returns:
            BatchFeature: The batch feature for the model, ready to be passed to the model.
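
        Example (illustrative sketch; `processor`, `model` and the `frontal` image are assumed to be prepared by the
        caller, and `max_new_tokens` is only indicative):

        ```python
        >>> inputs = processor.format_and_preprocess_phrase_grounding_input(
        >>>     frontal_image=frontal,
        >>>     phrase="Pleural effusion.",
        >>>     return_tensors="pt",
        >>> )
        >>> output = model.generate(**inputs, max_new_tokens=150)
        ```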
        """
        text, images = self.format_phrase_grounding_input(
            frontal_image=frontal_image,
            phrase=phrase,
            assistant_text=assistant_text,
        )
        return self(text=text, images=images, **kwargs)

    def _get_text_between_delimiters(self, text: str, begin_token: str, end_token: str):
        """
        This function splits the input text into a list of substrings based on the given begin and end tokens.

        Args:
            text (str):
                The input text to be split.
            begin_token (str):
                The begin token.
            end_token (str):
                The end token.

        Returns:
            list[str]: The list of substrings between the given begin and end tokens.

        Example:

        ```python
        >>> _get_text_between_delimiters("<obj>This is a grounded phrase</obj><obj>This is another grounded phrase</obj>", "<obj>", "</obj>")
        >>> # ["This is a grounded phrase", "This is another grounded phrase"]

        >>> _get_text_between_delimiters("<box><x10><y20><x30><y40></box><box><x50><y60><x70><y80></box>", "<box>", "</box>")
        >>> # ["<x10><y20><x30><y40>", "<x50><y60><x70><y80>"]
        ```
        """
        split_text = []
        while begin_token in text:
            assert text.startswith(begin_token)
            end_index = text.find(end_token)
            assert end_index != -1
            split_text.append(text[len(begin_token) : end_index])
            text = text[end_index + len(end_token) :]
        assert len(text) == 0
        return split_text

    def convert_output_to_plaintext_or_grounded_sequence(self, text: str):
        """
        This function converts the input text to a grounded sequence by extracting the grounded phrases and bounding
        boxes from the text. If the text is plaintext without any grounded phrases, it returns the text as is.

        Args:
            text (str):
                The input text to be converted.

        Returns:
            str | list[tuple[str, list[BoxType] | None]]: The grounded sequence.

        Example:

        ```python
        >>> convert_output_to_plaintext_or_grounded_sequence("<obj>grounded phrase <box><x55><y45><x70><y56></box></obj><obj>ungrounded phrase</obj>")
        >>> # [
        >>> # ("grounded phrase", [(0.555, 0.455, 0.705, 0.565)]),
        >>> # ("ungrounded phrase", None),
        >>> # ]

        >>> convert_output_to_plaintext_or_grounded_sequence("plain text")
        >>> # "plain text"
        ```
        """
        text = text.strip()

        # If no grounding special tokens are present, the output is plain text and is returned unchanged.
        if not any(
            [
                self.phrase_start_token in text,
                self.phrase_end_token in text,
                self.box_start_token in text,
                self.box_end_token in text,
            ]
        ):
            return text

        grounded_phrase_texts = self._get_text_between_delimiters(text, self.phrase_start_token, self.phrase_end_token)
        grounded_phrases = []
        for grounded_phrase_text in grounded_phrase_texts:
            if self.box_start_token in grounded_phrase_text or self.box_end_token in grounded_phrase_text:
                first_box_start_index = grounded_phrase_text.find(self.box_start_token)
                phrase_text = grounded_phrase_text[:first_box_start_index].strip()
                boxes_text = grounded_phrase_text[first_box_start_index:]
                boxes_text_list = self._get_text_between_delimiters(
                    boxes_text, self.box_start_token, self.box_end_token
                )
                boxes = []
                for box_text in boxes_text_list:
                    # Each box is encoded as four binned coordinates, e.g. "<x10><y20><x30><y40>", which are mapped
                    # back to normalised coordinates at the centre of each bin.
                    regex = r"<x(\d+?)><y(\d+?)><x(\d+?)><y(\d+?)>"
                    match = re.search(regex, box_text)
                    if match:
                        x_min, y_min, x_max, y_max = match.groups()
                        box = tuple(
                            (int(coord) + 0.5) / self.num_box_coord_bins for coord in (x_min, y_min, x_max, y_max)
                        )
                        assert all(0 <= coord <= 1 for coord in box), f"Invalid box coordinates: {box}"
                        boxes.append(box)
                    else:
                        raise ValueError(f"Invalid box coordinates: {box_text} not matching regex {regex}")
                grounded_phrases.append((phrase_text, boxes))
            else:
                grounded_phrases.append((grounded_phrase_text.lstrip(), None))
        return grounded_phrases

    @staticmethod
    def adjust_box_for_original_image_size(box, width: int, height: int):
        """
        This function adjusts the bounding boxes from the MAIRA-2 model output to account for the image processor
        cropping the image to be square prior to the model forward pass. The box coordinates are adjusted to be
        relative to the original shape of the image assuming the image processor cropped the image based on the length
        of the shortest side.

        Args:
            box (BoxType):
                The box to be adjusted, normalised to (0, 1).
            width (int):
                Original width of the image, in pixels.
            height (int):
                Original height of the image, in pixels.

        Returns:
            BoxType: The box normalised relative to the original size of the image.
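
        Example (worked illustration; a 1200x800 image is centre-cropped to 800x800 before processing, so a box
        predicted on the crop is shifted by the crop offset and re-normalised against the original width and height):

        ```python
        >>> adjust_box_for_original_image_size(box=(0.25, 0.25, 0.5, 0.5), width=1200, height=800)
        >>> # (0.3333333333333333, 0.25, 0.5, 0.5)
        ```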
        """
        crop_width = crop_height = min(width, height)
        x_offset = (width - crop_width) // 2
        y_offset = (height - crop_height) // 2

        norm_x_min, norm_y_min, norm_x_max, norm_y_max = box

        # Map the normalised crop coordinates back to absolute pixel coordinates in the original image.
        abs_x_min = int(norm_x_min * crop_width + x_offset)
        abs_x_max = int(norm_x_max * crop_width + x_offset)
        abs_y_min = int(norm_y_min * crop_height + y_offset)
        abs_y_max = int(norm_y_max * crop_height + y_offset)

        # Re-normalise relative to the original (uncropped) image size.
        adjusted_norm_x_min = abs_x_min / width
        adjusted_norm_x_max = abs_x_max / width
        adjusted_norm_y_min = abs_y_min / height
        adjusted_norm_y_max = abs_y_max / height

        return (adjusted_norm_x_min, adjusted_norm_y_min, adjusted_norm_x_max, adjusted_norm_y_max)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[LlavaProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to
        encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
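
        Example (illustrative sketch; the prompt is expected to contain the `<image>` placeholder once per provided
        image, and the exact set of returned keys depends on the configured tokenizer and image processor):

        ```python
        >>> image = Image.new("L", (512, 512))
        >>> inputs = processor(text="<image> Describe the image.", images=[image], return_tensors="pt")
        >>> sorted(inputs.keys())
        >>> # ['attention_mask', 'input_ids', 'pixel_values']
        ```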
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        # Check if images and text inputs are reversed for backward compatibility.
        images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(
            LlavaProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
        else:
            image_inputs = {}

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        prompt_strings = text
        if image_inputs.get("pixel_values") is not None:
            if self.patch_size is not None and self.vision_feature_select_strategy is not None:
                # Replace the image token in each prompt with as many placeholder tokens as there will be image
                # features, so that text tokens and image features can be merged in the model forward pass.
                pixel_values = image_inputs["pixel_values"]
                height, width = get_image_size(to_numpy_array(pixel_values[0]))
                num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
                if self.vision_feature_select_strategy == "default":
                    num_image_tokens -= 1

                prompt_strings = []
                for sample in text:
                    sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
                    prompt_strings.append(sample)

        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
        return BatchFeature(data={**text_inputs, **image_inputs})