"""Gradio app that samples Danbooru-style captions matching a set of tags.

The caption index is a parquet file whose row index holds the caption string
(tags joined by ", ") and which provides at least a ``score`` and a ``name``
column. Rows whose caption contains every requested tag are sampled, with a
probability proportional to how far their score exceeds a threshold.
"""

import html
import os
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import numpy as np
import pandas as pd

# Tags that are emoticons: their underscores are part of the tag and must not
# be turned into spaces when a caption is rendered.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    # NOTE(review): the bare "_" entry looks like a mangled "<o>_<o>" (angle
    # brackets stripped somewhere upstream) -- confirm. Both are listed so that
    # nothing that was protected before loses its protection.
    "_",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]

index_file = './caption_index.parquet'
df = pd.read_parquet(index_file)

_CHUNK_SIZE = 100000      # captions scanned per worker task
_MAX_WORKERS = 8          # size of the matching thread pool
_SAMPLE_SIZE = 5          # maximum number of captions returned per query
_SCORE_THRESHOLD = 5      # scores at or below this get a sampling weight of 0


def _parse_tags(user_input):
    """Turn the raw textbox value into a set of underscore-style tags.

    Pieces are split on commas and stripped, so "a, b", "a,b" and "a, b, "
    all yield {"a", "b"}. Empty pieces are dropped.
    """
    if not user_input:
        return set()
    pieces = (piece.strip() for piece in user_input.split(','))
    return {piece.replace(' ', '_') for piece in pieces if piece}


def _match_mask(captions, tags):
    """Return one bool per caption: True when it contains every tag."""
    return [tags.issubset(caption.split(', ')) for caption in captions]


def _filter_by_tags(frame, tags):
    """Return a copy of the rows of *frame* whose index caption has all *tags*.

    The index is scanned in chunks on a thread pool, as before, but only a
    boolean mask is built, so no chunk of the frame is copied per request
    and an empty frame no longer fails in ``pd.concat``.
    """
    captions = frame.index
    starts = range(0, len(captions), _CHUNK_SIZE)
    with ThreadPoolExecutor(max_workers=_MAX_WORKERS) as executor:
        parts = executor.map(
            lambda start: _match_mask(captions[start:start + _CHUNK_SIZE], tags),
            starts,
        )
        mask = [hit for part in parts for hit in part]
    return frame[np.array(mask, dtype=bool)].copy()


def _calculate_weight(score):
    """Map a score to a non-negative sampling weight.

    The weight is ``score - 5`` floored at 0. Scores that cannot be parsed as
    a number (including ``None``), NaN and infinities get weight 0.
    """
    try:
        weight = float(score) - _SCORE_THRESHOLD
    except (TypeError, ValueError):
        return 0.0
    if not np.isfinite(weight) or weight <= 0:
        return 0.0
    return weight


def _sample_positions(weights, sample_size, rng):
    """Pick up to *sample_size* distinct row positions according to *weights*.

    Only rows with a positive weight can be drawn, so fewer than
    *sample_size* positions are returned when there are not enough of them.
    When no row has a positive weight the draw falls back to uniform sampling
    instead of dividing by zero.
    """
    n_rows = len(weights)
    if n_rows == 0:
        return np.empty(0, dtype=int)

    total = weights.sum()
    if total <= 0:
        return rng.choice(n_rows, size=min(sample_size, n_rows), replace=False)

    # Without replacement, choice() cannot return more items than there are
    # non-zero probabilities.
    n_positive = int(np.count_nonzero(weights))
    return rng.choice(
        n_rows,
        size=min(sample_size, n_positive),
        replace=False,
        p=weights / total,
    )


def _format_caption(caption):
    """Render an index caption as a prompt string.

    Underscores become spaces except inside kaomoji tags, and parentheses
    are backslash-escaped.
    """
    tags = caption.split(', ')
    spaced = [tag if tag in kaomojis else tag.replace('_', ' ') for tag in tags]
    escaped = [tag.replace("(", r"\(").replace(")", r"\)") for tag in spaced]
    return ', '.join(escaped)


def process_input(user_input):
    """Sample captions that contain every tag in *user_input*.

    Args:
        user_input: Comma-separated tags as typed by the user. Spaces inside
            a tag are treated as underscores.

    Returns:
        A ``(html, count)`` tuple: an HTML fragment with one
        ``"<post url>: <caption>"`` entry per sampled row, and the total
        number of rows that matched the tags.
    """
    user_tags = _parse_tags(user_input)
    if not user_tags:
        # An empty tag set would match every row; blank input yields nothing.
        return '', 0

    filtered_df = _filter_by_tags(df, user_tags)
    filtered_df['weight'] = filtered_df['score'].apply(_calculate_weight)

    # A local generator avoids reseeding NumPy's process-wide random state on
    # every request.
    rng = np.random.default_rng()
    weights = filtered_df['weight'].to_numpy(dtype=float)
    positions = _sample_positions(weights, _SAMPLE_SIZE, rng)
    # Positional selection: captions may repeat in the index, and a label
    # lookup would return every row sharing a sampled caption.
    sampled_df = filtered_df.iloc[positions]

    output = []
    for caption, name in zip(sampled_df.index, sampled_df['name']):
        url = name.replace('danbooru_', 'https://danbooru.donmai.us/posts/')
        processed_caption = _format_caption(caption)
        # The result is shown in a gr.HTML component: escape tag text such as
        # "<|>_<|>" and use <br> so each entry is rendered on its own line.
        output.append(
            f"{html.escape(url)}: {html.escape(processed_caption, quote=False)}<br>\n"
        )
    return ''.join(output), len(filtered_df)


iface = gr.Interface(
    fn=process_input,
    inputs=gr.Textbox(label="Input tags separated by ', '"),
    outputs=[
        gr.HTML(),
        gr.Number(label="Matched Images Count"),
    ],
    title="Prompt Sampling",
    flagging_mode='never',
)

iface.launch()