File size: 5,005 Bytes
5c4ad21
 
 
 
f1148e7
5c4ad21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
import pandas as pd
import time
import json
import os
import plotly.graph_objects as go

# Wide page layout so the 5-column image grid rendered below has horizontal room.
# NOTE(review): Streamlit expects page config before other st.* output — keep this call first.
st.set_page_config(layout="wide")

@st.cache_resource
def load_and_preprocess_data():
    """Load the post table from the Parquet file named by ``$PARQUET_FILE``.

    Wrapped in ``st.cache_resource``: the returned objects are shared, not
    copied, so callers must treat them as read-only.

    Returns:
        tuple: ``(df, sorted_indices)``.

        ``df`` is indexed by ``post_id``, ordered from the highest id down, and
        its ``tags`` column is converted to Python sets.

        ``sorted_indices`` maps each sort-option label to a ``post_id`` index
        giving that display order.

    Raises:
        RuntimeError: if ``PARQUET_FILE`` is unset or empty.
    """
    start_time = time.perf_counter()  # monotonic clock for measuring durations
    parquet_path = os.getenv('PARQUET_FILE')
    if not parquet_path:
        # Otherwise pd.read_parquet(None) fails with a generic path-type error.
        raise RuntimeError("PARQUET_FILE environment variable is not set")
    df = pd.read_parquet(parquet_path)
    df = df.sort_values(by='post_id', ascending=False)
    # Sets make the per-row subset/exclusion tag checks cheap.
    # Assumes each tags cell is an iterable of tag strings — TODO confirm the parquet schema.
    df["tags"] = df["tags"].apply(set)
    df = df.set_index('post_id')

    # Precompute every display order once so a sort change only reorders rows.
    sorted_indices = {
        'Post ID (Descending)': df.index,
        'Post ID (Ascending)': df.index[::-1],
        'Clip Score': df['clip_aesthetic'].sort_values(ascending=False).index,
        'Siglip Score': df['clip_aesthetic_2_5'].sort_values(ascending=False).index,
    }
    print(f"Data loaded and preprocessed: {time.perf_counter() - start_time:.2f} seconds")
    return df, sorted_indices

st.title('Danbooru Images')
data, sorted_indices = load_and_preprocess_data()

# sidebar
st.sidebar.header('Filter Options')
st.sidebar.write('Adjust the filter options to refine the results.')
score_range = st.sidebar.slider('Select clip score range', min_value=0.0, max_value=10.0, value=(0.0, 10.0), step=0.1, help='Filter images based on their CLIP score range.')
score_range_v2 = st.sidebar.slider('Select siglip score range', min_value=0.0, max_value=10.0, value=(6.0, 10.0), step=0.1, help='Filter images based on their SigLIP score range.')
page_number = st.sidebar.number_input('Page', min_value=1, value=1, step=1, help='Navigate through the pages of filtered results.')
# Options are derived from the precomputed orderings so the two lists cannot drift.
# Dict insertion order keeps "Post ID (Descending)" first, i.e. the default.
sort_option = st.sidebar.selectbox('Sort by (slow)', options=list(sorted_indices), help='Select sorting option for the results.')

# user input: "tag" requires a tag, "-tag" excludes it
user_input_tags = st.text_input('Enter tags (space-separated)', help='Filter images based on tags. Use "-" to exclude tags.')
tag_tokens = user_input_tags.split()  # split() already drops surrounding whitespace
selected_tags = {tag for tag in tag_tokens if not tag.startswith('-')}
# A lone "-" would otherwise add the empty string as an excluded tag.
undesired_tags = {tag[1:] for tag in tag_tokens if tag.startswith('-') and len(tag) > 1}
print(f"Selected tags: {selected_tags}, Undesired tags: {undesired_tags}")

# Function to filter data based on user input
def filter_data(df, score_range, score_range_v2, selected_tags, sort_option,
                undesired_tags=None, sorted_indices=None):
    """Return the rows of *df* that pass the score and tag filters, in display order.

    Args:
        df: frame indexed by post id with ``clip_aesthetic``,
            ``clip_aesthetic_2_5`` and ``tags`` (a set per row) columns.
        score_range: inclusive ``(low, high)`` bounds on ``clip_aesthetic``.
        score_range_v2: inclusive ``(low, high)`` bounds on ``clip_aesthetic_2_5``.
        selected_tags: set of tags that a row must all contain.
        sort_option: key into *sorted_indices*. ``"Post ID (Descending)"``
            keeps *df*'s existing row order.
        undesired_tags: tags that a row must not contain. ``None`` falls back
            to the module-level ``undesired_tags``, which this function
            implicitly read before the parameter existed.
        sorted_indices: mapping of sort option to an ordered index. ``None``
            falls back to the module-level ``sorted_indices``.

    Returns:
        The filtered DataFrame. It always keeps all of *df*'s columns, even
        when no rows match.
    """
    start_time = time.perf_counter()
    if undesired_tags is None:
        undesired_tags = globals().get('undesired_tags', set())

    filtered_data = df[
        (df['clip_aesthetic'] >= score_range[0]) &
        (df['clip_aesthetic'] <= score_range[1]) &
        (df['clip_aesthetic_2_5'] >= score_range_v2[0]) &
        (df['clip_aesthetic_2_5'] <= score_range_v2[1])
    ]
    print(f"Data filtered based on scores: {time.perf_counter() - start_time:.2f} seconds")

    if selected_tags or undesired_tags:
        # Force a bool dtype. On an empty frame the mapped Series can come back
        # as object dtype, which df[...] treats as column labels instead of a
        # row mask, returning a frame with no columns at all.
        keep = filtered_data['tags'].map(
            lambda tags: selected_tags.issubset(tags) and undesired_tags.isdisjoint(tags)
        ).astype(bool)
        filtered_data = filtered_data[keep]

    if sort_option != "Post ID (Descending)":
        if sorted_indices is None:
            sorted_indices = globals()['sorted_indices']
        # Row filtering preserves relative order, so reordering after the tag
        # filter gives the same result while .loc handles fewer rows.
        order = sorted_indices[sort_option]
        order = order[order.isin(filtered_data.index)]
        filtered_data = filtered_data.loc[order]
        print(f"Applying indcies: {time.perf_counter() - start_time:.2f} seconds")

    print(f"Data filtered: {time.perf_counter() - start_time:.2f} seconds")
    return filtered_data

# Filter data
filtered_data = filter_data(data, score_range, score_range_v2, selected_tags, sort_option)
st.sidebar.write(f"Total filtered images: {len(filtered_data)}")

# Pagination
items_per_page = 30
start_idx = (page_number - 1) * items_per_page
end_idx = start_idx + items_per_page
current_data = filtered_data.iloc[start_idx:end_idx]
if current_data.empty and len(filtered_data) > 0:
    # Without this, a page number past the end silently renders an empty grid.
    last_page = -(-len(filtered_data) // items_per_page)  # ceiling division
    st.warning(f"Page {page_number} is past the last page ({last_page}).")

# Display the current page as a grid, columns_per_row images per row
columns_per_row = 5
for row_start in range(0, len(current_data), columns_per_row):
    row = current_data.iloc[row_start:row_start + columns_per_row]
    cols = st.columns(columns_per_row)
    for col, (_, row_data) in zip(cols, row.iterrows()):
        with col:
            st.image(row_data['large_file_url'], caption=f"ID: {row_data.name}, CLIP: {row_data['clip_aesthetic']:.2f}, SigLIP: {row_data['clip_aesthetic_2_5']:.2f}", use_column_width=True)
            
def histogram_slider(df, column1, column2, max_samples=5000, random_state=0):
    """Draw overlaid 50-bin histograms of two numeric columns of *df* in the sidebar.

    At most *max_samples* rows are sampled to keep plotting cheap. The sample
    is seeded with *random_state* so the chart stays the same across Streamlit
    reruns on unchanged data, e.g. when only the page number changes. Pass
    ``random_state=None`` to restore the previous unseeded sampling.
    """
    sample_data = df.sample(min(max_samples, len(df)), random_state=random_state)

    fig = go.Figure()
    for column in (column1, column2):
        fig.add_trace(go.Histogram(x=sample_data[column], nbinsx=50, name=column, opacity=0.75))
    fig.update_layout(
        barmode='overlay',
        bargap=0.1,
        height=200,
        xaxis=dict(showticklabels=True),
        yaxis=dict(showticklabels=True),
        margin=dict(l=0, r=0, t=0, b=0),
        legend=dict(orientation='h', yanchor='bottom', y=-0.4, xanchor='center', x=0.5),
    )
    st.sidebar.plotly_chart(fig, use_container_width=True, config={'displayModeBar': False})

# histogram of the two score columns over the current filter result
start_time = time.perf_counter()  # monotonic clock: correct for elapsed-time measurement
histogram_slider(filtered_data, 'clip_aesthetic', 'clip_aesthetic_2_5')
print(f"Histogram displayed: {time.perf_counter() - start_time:.2f} seconds")