import html
import os
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import numpy as np
import pandas as pd
# Emoticon tags whose underscores belong to the face itself. process_input
# leaves these untouched when it turns the underscores of ordinary tags into
# spaces.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
# Parquet file holding the caption index; the path is relative to the
# process working directory.
index_file = './caption_index.parquet'

# Read the index once at import time so every request reuses the same frame.
# process_input treats the frame's index as the comma-separated caption
# strings and reads the 'score' and 'name' columns.
df = pd.read_parquet(index_file)
# A score must exceed this baseline for its row to get a non-zero weight.
_SCORE_BASELINE = 5.0
# Upper bound on the number of captions returned for one query.
_MAX_SAMPLES = 5


def _parse_user_tags(user_input):
    """Turn the raw textbox value into a set of index-style tags.

    Tags are comma separated. Whitespace around each tag is ignored, spaces
    inside a tag become underscores, and empty fragments are dropped.
    """
    fragments = (fragment.strip() for fragment in (user_input or '').split(','))
    return {fragment.replace(' ', '_') for fragment in fragments if fragment}


def _score_to_weight(score):
    """Return the sampling weight ``max(score - 5, 0)`` for one score value.

    Anything that cannot be read as a finite number weighs nothing, so a bad
    cell can never poison the normalised probabilities.
    """
    try:
        weight = float(score) - _SCORE_BASELINE
    except (TypeError, ValueError):
        return 0.0
    # `weight > 0` is False for NaN; isfinite additionally rejects +inf,
    # which would turn every normalised probability into NaN.
    if weight > 0 and np.isfinite(weight):
        return weight
    return 0.0


def _sample_positions(weights, max_samples):
    """Draw distinct row positions with probability proportional to weight.

    At most *max_samples* positions are returned, and never more than the
    number of rows with a positive weight: zero-weight rows are not eligible,
    and asking for more would make the draw without replacement fail.
    """
    eligible = int(np.count_nonzero(weights))
    size = min(max_samples, eligible)
    if size == 0:
        return np.empty(0, dtype=np.intp)
    # A private generator avoids re-seeding the process-wide NumPy RNG.
    rng = np.random.default_rng()
    return rng.choice(len(weights), size=size, replace=False, p=weights / weights.sum())


def _format_caption(caption, kaomoji_set):
    """Rewrite an index caption into prompt form.

    Underscores become spaces except inside kaomoji tags, then every
    parenthesis is prefixed with a backslash.
    """
    readable = [
        tag if tag in kaomoji_set else tag.replace('_', ' ')
        for tag in caption.split(', ')
    ]
    return ', '.join(tag.replace('(', '\\(').replace(')', '\\)') for tag in readable)


def process_input(user_input):
    """Sample up to five captions that contain every requested tag.

    Args:
        user_input: Comma-separated tags as typed by the user.

    Returns:
        A ``(markup, match_count)`` tuple. ``markup`` is an HTML fragment with
        one ``<a>`` link plus formatted caption per sampled row, and
        ``match_count`` is the number of rows whose caption contains all of
        the tags, whether or not they were sampled. Rows are sampled without
        replacement, weighted by ``max(score - 5, 0)``; rows with zero weight
        are never returned, so ``markup`` can be empty while ``match_count``
        is positive.
    """
    user_tags = _parse_user_tags(user_input)
    if not user_tags:
        # An empty query matches nothing rather than the whole index.
        return '', 0

    # One plain pass over the captions: the matching is pure-Python and
    # CPU-bound, so splitting it across threads brings no speed-up.
    mask = np.fromiter(
        (user_tags.issubset(caption.split(', ')) for caption in df.index),
        dtype=bool,
        count=len(df),
    )
    matched = df.loc[mask]

    weights = np.array(
        [_score_to_weight(score) for score in matched['score']],
        dtype=float,
    )
    # Positional selection stays correct even if the same caption string
    # appears more than once in the index.
    sampled = matched.iloc[_sample_positions(weights, _MAX_SAMPLES)]

    kaomoji_set = frozenset(kaomojis)
    lines = []
    for caption, name in zip(sampled.index, sampled['name']):
        url = name.replace('danbooru_', 'https://danbooru.donmai.us/posts/')
        # Escape before embedding: tags such as "<o>_<o>" would otherwise be
        # parsed as HTML elements and vanish from the rendered output.
        safe_url = html.escape(url, quote=True)
        safe_caption = html.escape(_format_caption(caption, kaomoji_set), quote=False)
        lines.append(
            f"<a href='{safe_url}' target='_blank'>{safe_url}</a>: {safe_caption}<br>"
        )

    return ''.join(lines), len(matched)
# Web UI: one textbox of tags in; the sampled captions (as HTML) and the
# total number of matching rows out.
iface = gr.Interface(
    fn=process_input,
    inputs=gr.Textbox(label="Input tags separated by ', '"),
    outputs=[
        gr.HTML(),
        gr.Number(label="Matched Images Count"),
    ],
    title="Prompt Sampling",
    flagging_mode='never',
)

# Start serving as soon as the module is executed.
iface.launch()