import json

from transformers import AutoTokenizer, T5ForConditionalGeneration
from diffusers.utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def add_special_token(
    tokenizer,
    text_encoder,
    add_color,
    add_font,
    color_ann_path,
    font_ann_path,
    multilingual=False,
):
    """Register color/font special tokens and resize the encoder embeddings.

    Args:
        tokenizer: Object exposing ``add_tokens(tokens, special_tokens=True)``
            and ``__len__``; it is modified in place.
        text_encoder: Object exposing ``resize_token_embeddings(size)``; it is
            resized in place to ``len(tokenizer)`` after the tokens are added.
        add_color: When truthy, one token per entry of the color annotation
            file is registered.
        add_font: When truthy, one token per entry of the font annotation
            file is registered.
        color_ann_path: Path to a JSON file; only its number of entries is
            used. Read only when ``add_color`` is truthy.
        font_ann_path: Path to a JSON object. Read only when ``add_font`` is
            truthy.
        multilingual: When truthy, font tokens are built as
            ``<{first two chars of key}-font-{value}>`` from each key/value
            pair of the font file; otherwise they are indexed by position.

    Returns:
        None.
    """
    additional_special_tokens = []

    # Color tokens come first, then font tokens, so the ids assigned by the
    # tokenizer follow the same order as before.
    if add_color:
        with open(color_ann_path, 'r', encoding='utf-8') as f:
            idx_color_dict = json.load(f)
        # NOTE(review): these literals were empty strings in the previous
        # revision (angle-bracket markup appears to have been stripped).
        # Restored by analogy with the multilingual font token format --
        # confirm against the vocabulary the checkpoints were trained with.
        additional_special_tokens += [
            f'<color-{i}>' for i in range(len(idx_color_dict))
        ]

    if add_font:
        with open(font_ann_path, 'r', encoding='utf-8') as f:
            idx_font_dict = json.load(f)
        if multilingual:
            # The first two characters of each key are used as the token
            # prefix (presumably a language code -- verify against the
            # annotation file).
            additional_special_tokens += [
                f'<{font_code[:2]}-font-{font_idx}>'
                for font_code, font_idx in idx_font_dict.items()
            ]
        else:
            # NOTE(review): restored literal, see the color token note above.
            additional_special_tokens += [
                f'<font-{i}>' for i in range(len(idx_font_dict))
            ]

    tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
    # The embedding table must match the tokenizer size after the additions.
    text_encoder.resize_token_embeddings(len(tokenizer))


def load_byt5_and_byt5_tokenizer(
    byt5_name='google/byt5-small',
    special_token=False,
    color_special_token=False,
    font_special_token=False,
    color_ann_path='assets/color_idx.json',
    font_ann_path='assets/font_idx_512.json',
    huggingface_cache_dir=None,
    multilingual=False,
):
    """Load a ByT5 encoder and its tokenizer, optionally adding special tokens.

    Args:
        byt5_name: Model name or path passed to ``from_pretrained`` for both
            the tokenizer and the model.
        special_token: When falsy, no tokens are added and the color/font
            arguments below are ignored.
        color_special_token: Forwarded to ``add_special_token`` as
            ``add_color``.
        font_special_token: Forwarded to ``add_special_token`` as
            ``add_font``.
        color_ann_path: Path to the color annotation JSON file.
        font_ann_path: Path to the font annotation JSON file.
        huggingface_cache_dir: ``cache_dir`` passed to ``from_pretrained``.
        multilingual: Forwarded to ``add_special_token``.

    Returns:
        A ``(byt5_text_encoder, byt5_tokenizer)`` tuple. Note the order:
        encoder first, tokenizer second.
    """
    byt5_tokenizer = AutoTokenizer.from_pretrained(
        byt5_name,
        cache_dir=huggingface_cache_dir,
    )
    # Only the encoder half of the seq2seq model is kept.
    byt5_text_encoder = T5ForConditionalGeneration.from_pretrained(
        byt5_name,
        cache_dir=huggingface_cache_dir,
    ).get_encoder()

    if special_token:
        add_special_token(
            byt5_tokenizer,
            byt5_text_encoder,
            add_color=color_special_token,
            add_font=font_special_token,
            color_ann_path=color_ann_path,
            font_ann_path=font_ann_path,
            multilingual=multilingual,
        )

    logger.info('Loaded original byt5 weight')

    return byt5_text_encoder, byt5_tokenizer