from typing import Union, Optional, List
import torch
from diffusers.utils import logging
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
)
import numpy as np

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def get_t5_prompt_embeds(
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    device = device or text_encoder.device
    
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        # padding is intentionally left off here; the encoder output is zero-padded
        # to `max_sequence_length` further below instead.
        # padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    
    # Concatenate zeros up to max_sequence_length
    b, seq_len, dim = prompt_embeds.shape
    if seq_len < max_sequence_length:
        padding = torch.zeros(
            (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
        )
        prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)

    prompt_embeds = prompt_embeds.to(device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
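
# A hedged usage sketch (not part of the original file) showing how
# `get_t5_prompt_embeds` might be called. The checkpoint name, device, and dtype
# below are illustrative assumptions; kept as a comment because instantiating a
# T5 encoder downloads several GB of weights.
#
#   tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
#   text_encoder = T5EncoderModel.from_pretrained(
#       "google/t5-v1_1-xxl", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   prompt_embeds = get_t5_prompt_embeds(
#       tokenizer,
#       text_encoder,
#       prompt="a photo of an astronaut riding a horse on mars",
#       num_images_per_prompt=2,
#       max_sequence_length=128,
#   )
#   # prompt_embeds.shape -> (1 * 2, 128, hidden_dim)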

# in order to get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
    sigmas = timesteps / num_train_timesteps

    inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
    new_sigmas = sigmas[inds]
    return new_sigmas
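
# Illustrative worked example (an assumption added for clarity, not from the
# original file): with the defaults above, get_original_sigmas(1000, 5) picks
# indices [0, 249, 499, 749, 999] and returns sigmas of approximately
# [1.0, 0.751, 0.501, 0.251, 0.001], i.e. evenly spaced from 1.0 down to
# 1 / num_train_timesteps.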

def is_ng_none(negative_prompt):
    # Treat None, an empty string, or a list whose first entry is None/empty as "no negative prompt".
    return (
        negative_prompt is None
        or negative_prompt == ""
        or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
        or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
    )
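

# Minimal sanity-check sketch (added for illustration; the prompts and step
# counts below are assumptions, not part of the original file). Runs only when
# the module is executed directly.
if __name__ == "__main__":
    sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=5)
    print("sampled sigmas:", sigmas)

    assert is_ng_none(None)
    assert is_ng_none("")
    assert is_ng_none([None])
    assert is_ng_none([""])
    assert not is_ng_none("blurry, low quality")
    assert not is_ng_none(["blurry, low quality"])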