# 123123 / app.py
# KAZEKOI's picture
# Upload 2 files
# 7bb21b9 verified
import html
import os
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import numpy as np
import pandas as pd
# Emoticon-style tags. process_input() leaves these exactly as written when it
# turns the underscores of ordinary tags into spaces, because here the
# underscore is part of the emoticon itself.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
# Caption index, read once at import time and shared by every request.
# process_input() uses the frame's index as the comma-separated caption of each
# row and reads its 'score' and 'name' columns.
index_file = "./caption_index.parquet"
df = pd.read_parquet(path=index_file)
# One-pass translation table that backslash-escapes parentheses in a tag.
_PAREN_ESCAPES = str.maketrans({"(": "\\(", ")": "\\)"})


def _parse_tags(user_input):
    """Turn the raw textbox value into a set of underscore-style tags.

    Entries are separated by commas; surrounding whitespace is ignored, empty
    entries are dropped, and spaces inside a tag become underscores so they
    compare equal to the tags stored in the caption index.
    """
    return {
        tag.strip().replace(' ', '_')
        for tag in (user_input or '').split(',')
        if tag.strip()
    }


def _score_to_weight(score):
    """Map a score to a sampling weight of ``max(score - 5, 0)``.

    Missing, non-numeric, NaN and infinite scores all yield ``0.0`` so that
    the resulting weights can always be normalised into probabilities.
    """
    try:
        weight = float(score) - 5
    except (TypeError, ValueError):
        return 0.0
    return weight if np.isfinite(weight) and weight > 0 else 0.0


def process_input(user_input):
    """Sample up to five captions that contain every requested tag.

    Args:
        user_input: Comma-separated tags as typed by the user. Spaces inside a
            tag are treated as underscores; whitespace around commas is
            ignored.

    Returns:
        A ``(html, match_count)`` tuple. ``html`` holds one
        ``<a href=...>link</a>: caption<br>`` fragment per sampled row, with
        the caption HTML-escaped, underscores shown as spaces (except for the
        tags listed in ``kaomojis``) and parentheses backslash-escaped.
        ``match_count`` is the number of rows of ``df`` whose caption contains
        all requested tags, whether or not they were sampled.

    Rows are drawn without replacement with probability proportional to
    ``max(score - 5, 0)``. Rows whose weight is zero are never drawn, so fewer
    than five (possibly zero) captions are returned when too few matches have
    a positive weight. Input without any tag returns ``('', 0)``.
    """
    user_tags = _parse_tags(user_input)
    if not user_tags:
        return '', 0

    def matching_rows(chunk):
        # The frame's index holds the caption: tags joined by ', '.
        mask = np.fromiter(
            (user_tags.issubset(caption.split(', ')) for caption in chunk.index),
            dtype=bool,
            count=len(chunk),
        )
        return chunk.loc[mask]

    chunk_size = 100000
    chunks = [
        df.iloc[start:start + chunk_size]
        for start in range(0, len(df), chunk_size)
    ]
    with ThreadPoolExecutor(max_workers=8) as executor:
        matches = list(executor.map(matching_rows, chunks))
    # pd.concat() rejects an empty list, which happens when df has no rows.
    filtered_df = pd.concat(matches) if matches else df.iloc[:0]
    match_count = len(filtered_df)

    weights = filtered_df['score'].map(_score_to_weight).to_numpy(dtype=float)
    # Only positively weighted rows can be drawn; limiting the sample size to
    # their number keeps the weighted draw without replacement well defined.
    eligible = np.flatnonzero(weights > 0)
    sample_size = min(5, eligible.size)
    if sample_size:
        probabilities = weights[eligible] / weights[eligible].sum()
        # A private generator avoids reseeding NumPy's shared global state.
        rng = np.random.default_rng()
        picked = rng.choice(
            eligible, size=sample_size, replace=False, p=probabilities
        )
        # Positional lookup: captions are not guaranteed to be unique labels.
        sampled_df = filtered_df.iloc[picked]
    else:
        sampled_df = filtered_df.iloc[:0]

    output = []
    for caption, name in zip(sampled_df.index, sampled_df['name']):
        tags = [
            tag if tag in kaomojis else tag.replace('_', ' ')
            for tag in caption.split(', ')
        ]
        processed_caption = ', '.join(
            tag.translate(_PAREN_ESCAPES) for tag in tags
        )
        url = name.replace('danbooru_', 'https://danbooru.donmai.us/posts/')
        # Escape before embedding: tags such as "<o>_<o>" would otherwise be
        # parsed as markup by the HTML output component.
        link = html.escape(url)
        text = html.escape(processed_caption, quote=False)
        output.append(f"<a href='{link}' target='_blank'>{link}</a>: {text}<br>")
    return ''.join(output), match_count
# Gradio front end: a single textbox feeds process_input(), whose two return
# values are shown as the HTML sample list and the matched-row count.
_tag_input = gr.Textbox(label="Input tags separated by ', '")
_sample_view = gr.HTML()
_match_counter = gr.Number(label="Matched Images Count")

iface = gr.Interface(
    fn=process_input,
    inputs=_tag_input,
    outputs=[_sample_view, _match_counter],
    title="Prompt Sampling",
    flagging_mode="never",
)

# Start the web app as soon as the module is executed.
iface.launch()