from typing import List, Optional, Union

import numpy as np
import torch
from diffusers.utils import logging
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_t5_prompt_embeds(
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    device = device or text_encoder.device
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        # padding="max_length",  # padding is deferred: zeros are concatenated to the embeddings below
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    # Warn the user if part of the prompt was dropped by truncation.
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    # Concat zeros to max_sequence
    b, seq_len, dim = prompt_embeds.shape
    if seq_len < max_sequence_length:
        padding = torch.zeros(
            (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
        )
        prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)

    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Duplicate the embeddings for each image generated per prompt.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
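

# Usage sketch (a minimal example, not part of this module; the
# "google/t5-v1_1-xxl" checkpoint name, fp16 dtype, and CUDA device are
# illustrative assumptions, not prescribed by the code above):
#
#   tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
#   text_encoder = T5EncoderModel.from_pretrained(
#       "google/t5-v1_1-xxl", torch_dtype=torch.float16
#   ).to("cuda")
#   prompt_embeds = get_t5_prompt_embeds(
#       tokenizer,
#       text_encoder,
#       prompt="a photo of an astronaut riding a horse",
#       num_images_per_prompt=2,
#   )
#   # prompt_embeds.shape == (2, 128, text_encoder.config.d_model)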