Spaces:
Runtime error
Runtime error
File size: 4,676 Bytes
113c29e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# Fooocus GPT2 Expansion
# Algorithm created by Lvmin Zhang at 2023, Stanford
# If used inside Fooocus, any use is permitted.
# If used outside Fooocus, only non-commercial use is permitted (CC-By NC 4.0).
# This applies to the word list, vocab, model, and algorithm.
import os
import torch
import math
import ldm_patched.modules.model_management as model_management
from transformers.generation.logits_process import LogitsProcessorList
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from modules.config import path_fooocus_expansion
from ldm_patched.modules.model_patcher import ModelPatcher
# transformers.set_seed() forwards the seed to np.random.seed(), which is the
# source of this upper limit; seeds are reduced modulo this value before use.
SEED_LIMIT_NUMPY = 1 << 32

# Large negative additive bias applied to the logits of tokens that must not be
# sampled (a finite stand-in for negative infinity, as the name indicates).
neg_inf = -8192.0
def safe_str(x):
    """Return ``str(x)`` with runs of spaces collapsed and its edges trimmed.

    Args:
        x: Any object; it is converted with ``str()`` first.

    Returns:
        The text with every run of consecutive spaces reduced to a single
        space, and with leading/trailing commas, periods, spaces, carriage
        returns and newlines removed.
    """
    x = str(x)
    # Replacing a double space with a single one halves each run per pass;
    # looping until no double space is left collapses runs of any length.
    while '  ' in x:
        x = x.replace('  ', ' ')
    return x.strip(",. \r\n")
def remove_pattern(x, pattern):
    """Delete every occurrence of each substring in ``pattern`` from ``x``.

    The substrings are removed one after another in iteration order, so a
    later substring is matched against the text left by the earlier removals.

    Args:
        x: The string to clean.
        pattern: An iterable of substrings to strip out.

    Returns:
        The cleaned string.
    """
    cleaned = x
    for fragment in pattern:
        cleaned = cleaned.replace(fragment, '')
    return cleaned
class FooocusExpansion:
    """Prompt expander that samples extra words from a causal language model.

    The tokenizer and model are loaded from ``path_fooocus_expansion``.
    Sampling is restricted to the words listed in ``positive.txt`` in that
    directory: every other vocabulary entry receives an additive logit bias
    of ``neg_inf``.
    """

    def __init__(self):
        """Load tokenizer, word allow-list and model, and build the logit bias."""
        self.tokenizer = AutoTokenizer.from_pretrained(path_fooocus_expansion)

        # Read the allow-list inside a context manager so the file handle is
        # closed deterministically instead of being left to the garbage collector.
        words_path = os.path.join(path_fooocus_expansion, 'positive.txt')
        with open(words_path, encoding='utf-8') as f:
            lines = f.read().splitlines()

        # Each word is lower-cased and prefixed with 'Ġ' so that it can be
        # compared with vocabulary keys. NOTE(review): 'Ġ' is presumably the
        # tokenizer's marker for a leading space (GPT-2 style BPE) - confirm.
        # A set gives O(1) membership tests in the vocabulary loop below.
        positive_words = {'Ġ' + x.lower() for x in lines if x != ''}

        # Fetch the vocabulary once; it is used for both sizing and iteration.
        vocab = self.tokenizer.vocab

        # Start with every token penalised, then lift the penalty for the
        # tokens that appear in the allow-list.
        self.logits_bias = torch.zeros((1, len(vocab)), dtype=torch.float32) + neg_inf

        matched = 0
        for k, v in vocab.items():
            if k in positive_words:
                self.logits_bias[0, v] = 0
                matched += 1

        print(f'Fooocus V2 Expansion: Vocab with {matched} words.')

        self.model = AutoModelForCausalLM.from_pretrained(path_fooocus_expansion)
        self.model.eval()

        load_device = model_management.text_encoder_device()
        offload_device = model_management.text_encoder_offload_device()

        # MPS hack: when the load device is MPS, run and offload on the CPU.
        if model_management.is_device_mps(load_device):
            load_device = torch.device('cpu')
            offload_device = torch.device('cpu')

        use_fp16 = model_management.should_use_fp16(device=load_device)
        if use_fp16:
            self.model.half()

        self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
        print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.')

    @torch.no_grad()
    @torch.inference_mode()
    def logits_processor(self, input_ids, scores):
        """Add the allow-list bias to ``scores`` and penalise repeated tokens.

        Args:
            input_ids: Token ids generated so far; only row 0 is read.
            scores: Next-token logits; must be 2-D with a batch size of 1.

        Returns:
            ``scores`` plus the bias tensor.
        """
        assert scores.ndim == 2 and scores.shape[0] == 1

        # Keep the cached bias on the same device/dtype as the scores so the
        # conversion is not repeated on every step.
        self.logits_bias = self.logits_bias.to(scores)

        # Work on a copy: the per-step changes below must not leak into the
        # cached allow-list bias.
        bias = self.logits_bias.clone()

        # Penalise every token id already present in the sequence.
        bias[0, input_ids[0].to(bias.device).long()] = neg_inf

        # Token id 11 is always left unpenalised, even if it already occurred.
        # NOTE(review): id 11 is presumably the ',' token of this vocabulary - confirm.
        bias[0, 11] = 0

        return scores + bias

    @torch.no_grad()
    @torch.inference_mode()
    def __call__(self, prompt, seed):
        """Expand ``prompt`` with sampled words.

        Args:
            prompt: The text to expand. An empty string is returned unchanged.
            seed: Sampling seed; converted with ``int()`` and reduced modulo
                ``SEED_LIMIT_NUMPY`` before being passed to ``set_seed``.

        Returns:
            The cleaned, expanded prompt. If the tokenized prompt (with a
            trailing comma appended) already has a length that is a multiple
            of 75 tokens, the cleaned prompt is returned without expansion.
        """
        if prompt == '':
            return ''

        if self.patcher.current_device != self.patcher.load_device:
            print('Fooocus Expansion loaded by itself.')
            model_management.load_model_gpu(self.patcher)

        seed = int(seed) % SEED_LIMIT_NUMPY
        set_seed(seed)

        # The trailing comma separates the user's text from the generated words.
        prompt = safe_str(prompt) + ','

        tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
        tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
        tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)

        # Generate just enough tokens to reach the next multiple of 75.
        current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1])
        max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
        max_new_tokens = max_token_length - current_token_length

        if max_new_tokens == 0:
            # Nothing to generate: drop the comma that was appended above.
            return prompt[:-1]

        # https://huggingface.co/blog/introducing-csearch
        # https://huggingface.co/docs/transformers/generation_strategies
        features = self.model.generate(**tokenized_kwargs,
                                       top_k=100,
                                       max_new_tokens=max_new_tokens,
                                       do_sample=True,
                                       logits_processor=LogitsProcessorList([self.logits_processor]))

        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
        result = safe_str(response[0])

        return result
|