Delete bria_utils.py
bria_utils.py  +0 -71
bria_utils.py
DELETED
@@ -1,71 +0,0 @@
-from typing import Union, Optional, List
-import torch
-from diffusers.utils import logging
-from transformers import (
-    T5EncoderModel,
-    T5TokenizerFast,
-)
-import numpy as np
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-def get_t5_prompt_embeds(
-    tokenizer: T5TokenizerFast,
-    text_encoder: T5EncoderModel,
-    prompt: Union[str, List[str]] = None,
-    num_images_per_prompt: int = 1,
-    max_sequence_length: int = 128,
-    device: Optional[torch.device] = None,
-):
-    device = device or text_encoder.device
-
-    prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
-
-    text_inputs = tokenizer(
-        prompt,
-        # padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        add_special_tokens=True,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
-    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
-        logger.warning(
-            "The following part of your input was truncated because `max_sequence_length` is set to"
-            f" {max_sequence_length} tokens: {removed_text}"
-        )
-
-    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
-
-    # Concat zeros to max_sequence_length
-    b, seq_len, dim = prompt_embeds.shape
-    if seq_len < max_sequence_length:
-        padding = torch.zeros((b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
-        prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
-
-    prompt_embeds = prompt_embeds.to(device=device)
-
-    _, seq_len, _ = prompt_embeds.shape
-
-    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-    return prompt_embeds
-
-# in order to get the same sigmas as in training and sample from them
-def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
-    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
-    sigmas = timesteps / num_train_timesteps
-
-    inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
-    new_sigmas = sigmas[inds]
-    return new_sigmas
-
-def is_ng_none(negative_prompt):
-    return negative_prompt is None or negative_prompt == "" or (isinstance(negative_prompt, list) and (negative_prompt[0] is None or negative_prompt[0] == ""))
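For context, a minimal sketch of how the removed helpers were typically wired together (assuming `bria_utils` is still importable at this point in history; the T5 checkpoint name is a placeholder and is not taken from this commit):

import torch
from transformers import T5EncoderModel, T5TokenizerFast
from bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none

# Placeholder checkpoint; the encoder actually used by the pipeline is not stated in this commit.
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-small")
text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

prompt = "a photo of an astronaut riding a horse"
negative_prompt = ""  # is_ng_none treats None, "", [None] and [""] as "no negative prompt"

# Fixed-length T5 embeddings, zero-padded up to max_sequence_length and repeated per image.
prompt_embeds = get_t5_prompt_embeds(
    tokenizer,
    text_encoder,
    prompt=prompt,
    num_images_per_prompt=1,
    max_sequence_length=128,
    device=torch.device("cpu"),
)

# Only encode the negative prompt when one is actually provided.
if not is_ng_none(negative_prompt):
    negative_embeds = get_t5_prompt_embeds(tokenizer, text_encoder, prompt=negative_prompt)

# Sigmas matching the training schedule, subsampled to the requested number of inference steps.
sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=30)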