File size: 2,746 Bytes
7bb21b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import numpy as np
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import os

# Tags that are emoticons rather than words. Their underscores are part of the
# face, so process_input keeps them verbatim instead of turning "_" into " ".
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]

# Parquet file holding the caption index. The path is relative to the current
# working directory, so the app must be started from the folder containing it.
index_file = './caption_index.parquet'

# Loaded once at import time; process_input reads this frame on every call.
df = pd.read_parquet(path=index_file)


def process_input(user_input):
    """Sample up to five captions from ``df`` that contain every requested tag.

    Args:
        user_input: Tags typed by the user, separated by ", ". Surrounding
            whitespace is ignored. Spaces inside a tag become underscores to
            match how tags are spelled in the caption index.

    Returns:
        A ``(html_text, count)`` tuple. ``html_text`` has one line per sampled
        row: an HTML-escaped link built from the row's ``name``, then the
        caption reformatted for display. ``count`` is the number of rows whose
        caption (the frame's index) contains all requested tags.

    Rows are drawn without replacement, with probability proportional to
    ``score - 5``. Rows whose weight is not positive are skipped. If no
    matching row has a positive weight, matches are drawn uniformly instead.
    """
    import html  # only needed here, to escape the generated markup

    user_tags = {
        part.strip().replace(' ', '_')
        for part in (user_input or '').split(', ')
        if part.strip()
    }
    if not user_tags:
        # Blank input matches nothing. An empty subset test would match every row.
        return '', 0

    # One pass over the captions. Threads would not help: the per-caption work
    # is pure Python and holds the GIL.
    mask = np.fromiter(
        (user_tags.issubset(caption.split(', ')) for caption in df.index),
        dtype=bool,
        count=len(df),
    )
    filtered_df = df.loc[mask]
    if filtered_df.empty:
        return '', 0

    def calculate_weight(score):
        # Missing, unparseable or non-finite scores, and scores <= 5, weigh 0.
        try:
            weight = float(score) - 5
        except (TypeError, ValueError):
            return 0.0
        return weight if np.isfinite(weight) and weight > 0 else 0.0

    weights = np.array([calculate_weight(s) for s in filtered_df['score']], dtype=float)

    # Local generator: do not reseed or share NumPy's global RNG from a
    # request handler that may run concurrently.
    rng = np.random.default_rng()
    candidates = np.flatnonzero(weights > 0)
    if candidates.size == 0:
        # Normalising an all-zero weight vector would yield NaN probabilities.
        candidates = np.arange(len(filtered_df))
        probs = None
    else:
        probs = weights[candidates] / weights[candidates].sum()
    # With replace=False, never ask for more draws than there are candidates.
    sample_size = min(5, candidates.size)
    # Sample positions, not index labels: captions may repeat in the index,
    # and .loc on a repeated label would return every duplicate row.
    positions = rng.choice(candidates, size=sample_size, replace=False, p=probs)
    sampled_df = filtered_df.iloc[positions]

    output = []
    for caption, name in zip(sampled_df.index, sampled_df['name']):
        processed_tags = [
            tag if tag in kaomojis else tag.replace('_', ' ')
            for tag in caption.split(', ')
        ]
        # Backslash-escape parentheses (presumably so a prompt parser does not
        # read them as emphasis syntax; TODO confirm).
        processed_tags = [
            tag.replace("(", "\\(").replace(")", "\\)") for tag in processed_tags
        ]
        processed_caption = ', '.join(processed_tags)
        url = name.replace('danbooru_', 'https://danbooru.donmai.us/posts/')
        # Escape everything inserted into the markup. Kaomoji such as "<o>_<o>"
        # would otherwise be parsed as tags and vanish from the rendered output.
        safe_url = html.escape(url, quote=True)
        output.append(
            f"<a href='{safe_url}' target='_blank'>{safe_url}</a>: "
            f"{html.escape(processed_caption)}<br>"
        )

    return ''.join(output), len(filtered_df)

# UI wiring: a single free-text tag box feeds process_input. Its two return
# values go to an HTML panel and a numeric match counter.
tag_box = gr.Textbox(label="Input tags separated by ', '")
result_outputs = [
    gr.HTML(),
    gr.Number(label="Matched Images Count"),
]

iface = gr.Interface(
    fn=process_input,
    inputs=tag_box,
    outputs=result_outputs,
    title="Prompt Sampling",
    flagging_mode='never',
)
iface.launch()