Diffusers
TalHach61 committed on
Commit d12d68f · verified · 1 Parent(s): 59bcbc6

Delete bria_utils.py

Files changed (1)
  1. bria_utils.py +0 -71
bria_utils.py DELETED
@@ -1,71 +0,0 @@
- from typing import List, Optional, Union
-
- import numpy as np
- import torch
- from diffusers.utils import logging
- from transformers import (
-     T5EncoderModel,
-     T5TokenizerFast,
- )
-
- logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-
- def get_t5_prompt_embeds(
-     tokenizer: T5TokenizerFast,
-     text_encoder: T5EncoderModel,
-     prompt: Union[str, List[str]] = None,
-     num_images_per_prompt: int = 1,
-     max_sequence_length: int = 128,
-     device: Optional[torch.device] = None,
- ):
-     device = device or text_encoder.device
-
-     prompt = [prompt] if isinstance(prompt, str) else prompt
-     batch_size = len(prompt)
-
-     text_inputs = tokenizer(
-         prompt,
-         # padding="max_length",  # tokenizer padding is skipped; zeros are appended below instead
-         max_length=max_sequence_length,
-         truncation=True,
-         add_special_tokens=True,
-         return_tensors="pt",
-     )
-     text_input_ids = text_inputs.input_ids
-     untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-     if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-         removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
-         logger.warning(
-             "The following part of your input was truncated because `max_sequence_length` is set to "
-             f"{max_sequence_length} tokens: {removed_text}"
-         )
-
-     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
-
-     # Pad the embeddings with zeros up to max_sequence_length
-     b, seq_len, dim = prompt_embeds.shape
-     if seq_len < max_sequence_length:
-         padding = torch.zeros(
-             (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
-         )
-         prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
-
-     prompt_embeds = prompt_embeds.to(device=device)
-
-     _, seq_len, _ = prompt_embeds.shape
-
-     # Duplicate text embeddings for each generation per prompt, using an mps-friendly method
-     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-     prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-     return prompt_embeds
-
-
- # Recover the same sigmas as used in training and subsample them for inference
- def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
-     timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
-     sigmas = timesteps / num_train_timesteps
-
-     inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
-     new_sigmas = sigmas[inds]
-     return new_sigmas
-
-
- def is_ng_none(negative_prompt):
-     return (
-         negative_prompt is None
-         or negative_prompt == ""
-         or (isinstance(negative_prompt, list) and (negative_prompt[0] is None or negative_prompt[0] == ""))
-     )
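
For reference, a minimal sketch of how the deleted helpers were typically called. This assumes bria_utils.py is still importable and uses the T5 checkpoint name "google/t5-v1_1-base" purely for illustration; neither the checkpoint nor the call sites are taken from this commit.

# Hypothetical usage sketch (not part of this commit).
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from bria_utils import get_t5_prompt_embeds, get_original_sigmas  # assumes the deleted module is on the path

tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-base")  # assumed checkpoint
text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-base")

prompt_embeds = get_t5_prompt_embeds(
    tokenizer,
    text_encoder,
    prompt="a photo of an astronaut riding a horse",
    num_images_per_prompt=2,
    max_sequence_length=128,
)
# Shape: (batch_size * num_images_per_prompt, max_sequence_length, hidden_dim)

sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=30)
# 30 sigmas subsampled from the 1000 training sigmas, descending from 1.0 toward 0.001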