mutimodal / app.py
warhawkmonk's picture
Update app.py
10719a6 verified
raw
history blame
24.3 kB
import pandas as pd
from PIL import Image
import streamlit as st
import cv2
from streamlit_drawable_canvas import st_canvas
import torch
from diffusers import AutoPipelineForInpainting
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from sentence_transformers import SentenceTransformer,util
from streamlit_image_select import image_select
import os
import fitz
import PyPDF2
import requests
from streamlit_navigation_bar import st_navbar
from langchain_community.llms import Ollama
import base64
from io import BytesIO
from PIL import Image, ImageDraw
from streamlit_lottie import st_lottie
from streamlit_option_menu import option_menu
import json
from transformers import pipeline
import streamlit as st
from streamlit_modal import Modal
import streamlit.components.v1 as components
from datetime import datetime
from streamlit_js_eval import streamlit_js_eval
from streamlit_pdf_viewer import pdf_viewer
def consume_llm_api(prompt):
    """
    Send a prompt to the LLM API and yield the streamed response as text.

    The request body is ``{"prompt": prompt}`` and the response is read in
    streaming mode. Bytes are decoded incrementally, so a multi-byte UTF-8
    character that is split across two network chunks is reassembled instead
    of raising ``UnicodeDecodeError`` and ending the stream early.

    Args:
        prompt: Text sent to the API under the ``"prompt"`` key.

    Yields:
        str: Successive decoded pieces of the response body.

    Errors are printed rather than raised, so on failure the generator simply
    ends (possibly after having yielded part of the response).
    """
    import codecs  # local import: only this function needs it

    url = "https://wise-eagles-send.loca.lt/api/llm-response"
    headers = {"Content-Type": "application/json"}
    payload = {"prompt": prompt}
    # Keeps undecoded trailing bytes between chunks; invalid bytes become U+FFFD
    # rather than aborting the whole response.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    try:
        print("Sending prompt to the LLM API...")
        with requests.post(url, json=payload, headers=headers, stream=True) as response:
            response.raise_for_status()
            print("Response from LLM API:\n")
            # chunk_size=None hands over data as soon as it arrives.
            for chunk in response.iter_content(chunk_size=None):
                text = decoder.decode(chunk)
                if text:
                    yield text
            # Flush whatever is still buffered in the decoder.
            tail = decoder.decode(b"", final=True)
            if tail:
                yield tail
    except requests.RequestException as e:
        print(f"Error consuming API: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")
def send_prompt():
    """Return the fixed instruction sentence placed between the retrieved context and the user's question."""
    instruction = "please respond according to the prompt asked below from the above context"
    return instruction
def image_to_base64(image_path):
    """Read the file at ``image_path`` in binary mode and return its content as a base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
@st.cache_resource
def load_model():
    """
    Load the Kandinsky 2.2 inpainting pipeline once and cache it.

    The pipeline runs on CUDA in float16 when a GPU is available. Without a
    GPU it falls back to CPU in float32 instead of failing on ``.to("cuda")``.

    Returns:
        The pipeline object returned by ``AutoPipelineForInpainting.from_pretrained``.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision is only used on the GPU path; the CPU path keeps float32.
    dtype = torch.float16 if use_cuda else torch.float32
    pipeline_ = AutoPipelineForInpainting.from_pretrained(
        "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=dtype
    ).to(device)
    return pipeline_
# @st.cache_resource
def prompt_improvment(pre_prompt):
    """
    Ask the LLM to rewrite ``pre_prompt`` as a richer image description.

    A fixed enhancement instruction is appended on a new line after the
    user's prompt, and the combined text is passed to ``consume_llm_api``.

    Returns:
        The generator returned by ``consume_llm_api`` (streamed text pieces).
    """
    enhancement = "Please use details from the prompt mentioned above, focusing only what user is thinking with the prompt and also add 8k resolution. Its a request only provide image description and brief prompt no other text."
    return consume_llm_api(pre_prompt + "\n" + enhancement)
def process_pdf(file):
    """
    Extract the text of every page of a PDF into ``Document`` objects.

    Args:
        file: Path of the PDF file to open.

    Returns:
        list: One ``Document(page_content=text)`` per page, in page order;
        pages whose extracted text is empty or ``None`` are left out.
    """
    with open(file, "rb") as pdf_handle:
        reader = PyPDF2.PdfReader(pdf_handle)
        # The generator is consumed while the file is still open.
        page_texts = (page.extract_text() for page in reader.pages)
        documents = [Document(page_content=text) for text in page_texts if text]
    return documents
def numpy_to_list(array):
    """
    Convert an array into (nested) Python lists of ``int``.

    Integer and boolean ndarrays are converted in one vectorized step, which
    avoids a Python-level call per pixel when whole images are passed in.
    Every other input is walked element by element: ndarray elements are
    converted recursively, anything else goes through ``int()`` (so floats
    are truncated toward zero).

    Args:
        array: A numpy array, or any iterable whose items are numpy arrays
            or values accepted by ``int()``.

    Returns:
        list: Nested lists mirroring the input's structure, with ``int`` leaves.
    """
    if (
        isinstance(array, np.ndarray)
        and array.ndim > 0
        and array.dtype.kind in "iub"
    ):
        if array.dtype.kind == "b":
            # tolist() on a bool array would give True/False, not 1/0.
            array = array.astype(np.int64)
        return array.tolist()
    return [
        numpy_to_list(value) if isinstance(value, np.ndarray) else int(value)
        for value in array
    ]
@st.cache_resource
def llm_text_response():
    """Create an Ollama client for ``llama3:latest`` (``num_ctx=1000``) and return its ``stream`` method."""
    return Ollama(model="llama3:latest", num_ctx=1000).stream
def model_single_out(prompt):
    """Run the cached pipeline from ``load_model`` on ``prompt`` and return the first image of the result."""
    result = load_model()(prompt)
    return result.images[0]
def model_out_put(init_image,mask_image,prompt,negative_prompt):
    """
    Send an image, a mask and the prompts to the remote API and return the result image.

    Both images are converted to numpy arrays and then to nested lists of
    ints (no base64 is involved) before being posted as JSON. The response is
    expected to carry the output pixels as a nested list under ``"img"``.

    Args:
        init_image: Image to edit; anything ``np.array`` can convert.
        mask_image: Mask image; anything ``np.array`` can convert.
        prompt: Text prompt sent under ``"prompt"``.
        negative_prompt: Text sent under ``"negative_prompt"``.

    Returns:
        PIL.Image.Image: Image built from the returned pixel data as uint8.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        ValueError: If the response has no (or an empty) ``"img"`` field.
    """
    API_URL = "https://7716-205-196-17-124.ngrok-free.app/api/llm-response"
    payload = {
        "prompt": prompt,
        "initial_img": numpy_to_list(np.array(init_image)),
        "masked_img": numpy_to_list(np.array(mask_image)),
        "negative_prompt": negative_prompt,
    }
    response_ = requests.post(API_URL, json=payload)
    # Fail here with the HTTP status rather than later while parsing an error body.
    response_.raise_for_status()
    output_pixels = response_.json().get("img")
    if not output_pixels:
        # Previously an absent "img" surfaced as an obscure numpy conversion error.
        raise ValueError("inpainting API response did not contain an 'img' field")
    return Image.fromarray(np.array(output_pixels, dtype=np.uint8))
@st.cache_resource
def multimodel():
    """Build the text-classification pipeline from the model stored at ``/home/user/app/model_path/`` (cached)."""
    return pipeline("text-classification", model="/home/user/app/model_path/")
def multimodel_output(prompt):
    """
    Classify ``prompt`` with the cached text-classification pipeline.

    Returns:
        The ``'label'`` value of the first prediction. The script compares it
        with ``"LABEL_0"`` to decide whether to take the image-generation path.
    """
    classifier = multimodel()
    predictions = classifier(prompt)
    top_prediction = predictions[0]
    return top_prediction['label']
def d4_to_3d(image):
    """
    Collapse the last (channel) axis of an image into a boolean mask.

    A pixel is ``True`` when at least one of its channel values is greater
    than zero, otherwise ``False``. The computation is vectorized instead of
    looping over every pixel in Python.

    Args:
        image: Array-like indexed as ``[row][column][channel]``.

    Returns:
        numpy.ndarray: Boolean array with the channel axis removed.
    """
    pixels = np.asarray(image)
    return (pixels > 0).any(axis=-1)
# ---- Page configuration, sidebar drawing options and per-session defaults ----
st.set_page_config(layout="wide")
# st.write(str(os.getcwd()))
# Ask the browser for its screen dimensions; they are used further down to
# size the drawing canvas and the PDF page preview.
# NOTE(review): streamlit_js_eval presumably returns None until the browser
# has answered, which would break the `//` arithmetic below — confirm.
screen_width = streamlit_js_eval(label="screen.width",js_expressions='screen.width')
screen_height = streamlit_js_eval(label="screen.height",js_expressions='screen.height')
# Stays None unless the IMAGES tab assigns the picked image to it.
img_selection=None
# Specify canvas parameters in application
drawing_mode = st.sidebar.selectbox(
"Drawing tool:", ("freedraw","point", "line", "rect", "circle", "transform")
)
# `dictionary` is just an alias for st.session_state. Keys created below:
#   every_prompt_with_val - list of (prompt, value) tuples; value is the
#                           marker "@working", a response text, or an image
#   current_image         - generated images, newest inserted at the front
#   prompt_collection     - prompts offered in the PROMPT IMPROVEMENT tab
#   user                  - id of the logged-in user, None when logged out
#   current_session       - not referenced again in the visible code
#   image_movement        - last image picked in the IMAGES tab
dictionary=st.session_state
if "every_prompt_with_val" not in dictionary:
dictionary['every_prompt_with_val']=[]
if "current_image" not in dictionary:
dictionary['current_image']=[]
if "prompt_collection" not in dictionary:
dictionary['prompt_collection']=[]
if "user" not in dictionary:
dictionary['user']=None
if "current_session" not in dictionary:
dictionary['current_session']=None
if "image_movement" not in dictionary:
dictionary['image_movement']=None
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 20)
# point_display_radius only exists when the "point" tool is selected; the
# canvas call later guards its use with the same condition.
if drawing_mode == 'point':
point_display_radius = st.sidebar.slider("Point display radius: ", 1, 25, 3)
stroke_color = '#000000'
bg_color = "#eee"
# Two-column layout: column1 holds the canvas / document view and the chat
# input, column2 holds the HISTORY tabs.
column1,column2=st.columns([0.7,0.35])
# Load the feedback records file.
# NOTE(review): dateTimeRecord is not read again in the visible code — confirm it is needed.
with open("/home/user/app/DataBase/datetimeRecords.json","r") as read:
dateTimeRecord=json.load(read)
# ---- Right-hand column: chat history, generated images, prompt improvement, login ----
with column2:
st.header("HISTORY")
tab1,tab2,tab3,tab4=st.tabs(["CHAT HISTORY","IMAGES","PROMPT IMPROVEMENT","LOGIN"])
# -- Tab 1: chat history, newest entry first --
with tab1:
# Empty history: show a placeholder header and an animation.
if not len(dictionary['every_prompt_with_val']):
st.header("I will store all the chat for the current session")
with open("/home/user/app/lotte_animation_saver/animation_4.json") as read:
url_json=json.load(read)
st_lottie(url_json,height = 400)
else:
with st.container(height=600):
for index,prompts_ in enumerate(dictionary['every_prompt_with_val'][::-1]):
# "@working" marks an entry whose answer has not been fetched yet. For the
# newest entry (index 0) the answer is streamed from the LLM API and the
# last stored entry is replaced by (prompt, answer text).
if prompts_[-1]=="@working":
if index==0:
# Show only the user's question: the part after the send_prompt() sentence
# when the prompt was wrapped with retrieved context.
st.write(prompts_[0].split(send_prompt())[-1].upper() if send_prompt() in prompts_[0] else prompts_[0].upper())
data_need=st.write_stream(consume_llm_api(prompts_[0]))
dictionary['every_prompt_with_val'][-1]=(prompts_[0],str(data_need))
# Text answers are shown in a text area; the newest one gets a taller box.
elif isinstance(prompts_[-1],str):
show_case_text=prompts_[0].split(send_prompt())[-1].upper() if send_prompt() in prompts_[0] else prompts_[0].upper()
if index==0:
st.text_area(label=show_case_text,value=prompts_[-1],height=500,key=str(index))
else:
st.text_area(label=show_case_text,value=prompts_[-1],key=str(index))
# Anything else is treated as an image (it is used through .size and .crop below).
else:
st.write(prompts_[0].upper())
with st.container(height=400):
format1,format2=st.columns([0.2,0.8])
# Left cell: a fixed image file cropped to a square with a circular alpha mask.
with format1:
new_img=Image.open("/home/user/app/ALL_image_formation/image_gen.png")
st.write("<br>",unsafe_allow_html=True)
size = min(new_img.size)
mask = Image.new('L', (size, size), 0)
draw = ImageDraw.Draw(mask)
draw.ellipse((0, 0, size, size), fill=255)
image = new_img.crop((0, 0, size, size))
image.putalpha(mask)
st.image(image)
# Right cell: the stored image with the same circular treatment.
with format2:
st.write("<br>",unsafe_allow_html=True)
size = min(prompts_[-1].size)
mask = Image.new('L', (size, size), 0)
draw = ImageDraw.Draw(mask)
draw.ellipse((0, 0, size, size), fill=255)
# Crop the image to a square and apply the mask
image = prompts_[-1].crop((0, 0, size, size))
image.putalpha(mask)
st.image(image)
# -- Tab 2: gallery of generated images with a download button --
with tab2:
if "current_image" in dictionary and len(dictionary['current_image']):
with st.container(height=600):
dictinory_length=len(dictionary['current_image'])
# NOTE(review): the `else None` branch looks unreachable because the
# enclosing `if` already requires a non-empty list — confirm.
img_selection = image_select(
label="",
images=dictionary['current_image'] if len(dictionary['current_image'])!=0 else None,
)
# Move the picked image to the front of the list.
if img_selection in dictionary['current_image']:
dictionary['current_image'].remove(img_selection)
dictionary['current_image'].insert(0,img_selection)
# Rerun only when the selection differs from the remembered one.
if dictionary['image_movement']!=img_selection:
dictionary['image_movement']=img_selection
st.rerun() # st.rerun()
# The picked image is written to a temporary "image.png" in the working
# directory, offered for download, and the file is removed again.
img_selection.save("image.png")
with open("image.png", "rb") as file:
downl=st.download_button(label="DOWNLOAD",data=file,file_name="image.png",mime="image/png")
os.remove("image.png")
else:
st.header("This section will store the updated images")
with open("/home/user/app/lotte_animation_saver/animation_1.json") as read:
url_json=json.load(read)
st_lottie(url_json,height = 400)
# -- Tab 3: ask the LLM to improve a previously used image prompt --
with tab3:
if len(dictionary['prompt_collection'])!=0:
with st.container(height=600):
# The first option is a placeholder; only a real prompt triggers a request.
prompt_selection=st.selectbox(label="Select the prompt for improvment",options=["Mention below are prompt history"]+dictionary["prompt_collection"],index=0)
if prompt_selection!="Mention below are prompt history":
generated_prompt=prompt_improvment(prompt_selection)
# NOTE(review): prompt_improvment returns the generator from consume_llm_api,
# so the generator object (not the improved text) is what gets stored
# here — confirm that is intended.
dictionary['generated_image_prompt'].append(generated_prompt)
st.write_stream(generated_prompt)
else:
st.header("This section will provide prompt improvement section")
with open("/home/user/app/lotte_animation_saver/animation_3.json") as read:
url_json=json.load(read)
st_lottie(url_json,height = 400)
# -- Tab 4: login / sign up when logged out, issue reports when logged in --
with tab4:
# with st.container(height=600):
if not dictionary['user'] :
with st.form("my_form"):
# st.header("Please login for save your data")
with open("/home/user/app/lotte_animation_saver/animation_5.json") as read:
url_json=json.load(read)
st_lottie(url_json,height = 200)
user_id = st.text_input("user login")
password = st.text_input("password",type="password")
submitted_login = st.form_submit_button("Submit")
# Every form must have a submit button.
# The login file maps user id -> password and is compared directly.
# NOTE(review): passwords are stored and compared in plain text.
if submitted_login:
with open("/home/user/app/DataBase/login.json","r") as read:
login_base=json.load(read)
if user_id in login_base and login_base[user_id]==password:
dictionary['user']=user_id
st.rerun()
else:
st.error("userid or password incorrect")
st.write("working")
# Sign-up dialog, opened by the "sign up" button.
modal = Modal(
"Sign up",
key="demo-modal",
padding=10, # default value
max_width=600 # default value
)
open_modal = st.button("sign up")
if open_modal:
modal.open()
if modal.is_open():
with modal.container():
with st.form("my_form1"):
sign_up_column_left,sign_up_column_right=st.columns(2)
with sign_up_column_left:
with open("/home/user/app/lotte_animation_saver/animation_6.json") as read:
url_json=json.load(read)
st_lottie(url_json,height = 200)
with sign_up_column_right:
user_id = st.text_input("user login")
password = st.text_input("password",type="password")
submitted_signup = st.form_submit_button("Submit")
# A new user id is added to the login file and the user is logged in;
# an id that already exists is rejected.
if submitted_signup:
with open("/home/user/app/DataBase/login.json","r") as read:
login_base=json.load(read)
if not login_base:
login_base={}
if user_id not in login_base:
login_base[user_id]=password
with open("/home/user/app/DataBase/login.json","w") as write:
json.dump(login_base,write,indent=2)
st.success("you are a part now")
dictionary['user']=user_id
modal.close()
else:
st.error("user id already exists")
# Logged-in view: list stored reports (newest first) and accept a new one.
else:
st.header("REPORTED ISSUES")
with st.container(height=370):
with open("/home/user/app/DataBase/datetimeRecords.json") as feedback:
temp_issue=json.load(feedback)
# NOTE(review): this indexes 'database' directly, while the submit code
# below creates the key when it is missing — a file without it would
# presumably raise KeyError here; confirm.
arranged_feedback=reversed(temp_issue['database'])
# Each record is saved below as (timestamp, feedback text, user):
# report[-1] is the user and report[1] the feedback text.
for report in arranged_feedback:
user_columns,user_feedback=st.columns([0.3,0.8])
with user_columns:
st.write(report[-1])
with user_feedback:
st.write(report[1])
feedback=st.text_area("Feedback Report and Improvement",placeholder="")
summit=st.button("submit")
# Append the new report to the JSON file and write the whole file back.
if summit:
with open("/home/user/app/DataBase/datetimeRecords.json","r") as feedback_sumit:
temp_issue_submit=json.load(feedback_sumit)
if "database" not in temp_issue_submit:
temp_issue_submit["database"]=[]
temp_issue_submit["database"].append((str(datetime.now()),feedback,dictionary['user']))
with open("/home/user/app/DataBase/datetimeRecords.json","w") as feedback_sumit:
json.dump(temp_issue_submit,feedback_sumit)
# st.rerun()
# ---- Sidebar uploads and choice of the working image ----
bg_image = st.sidebar.file_uploader("PLEASE UPLOAD IMAGE FOR EDITING:", type=["png", "jpg"])
bg_doc = st.sidebar.file_uploader("PLEASE UPLOAD DOC FOR PPT/PDF/STORY:", type=["pdf","xlsx"])
# session_state["bg_image"] remembers the upload that was last used for
# generation; it is updated in the prompt-handling code further down.
if "bg_image" not in dictionary:
dictionary["bg_image"]=None
# If an image was picked in the IMAGES tab and the upload has not changed
# since the last generation, work on the image at the front of the list;
# otherwise use the uploaded file, or None when nothing was uploaded.
if img_selection and dictionary['bg_image']==bg_image:
gen_image=dictionary['current_image'][0]
else:
if bg_image:
gen_image=Image.open(bg_image)
else:
gen_image=None
# ---- Left-hand column: document preview + indexing, or the drawing canvas ----
with column1:
# Create a canvas component
# Only the wide middle column is used; the two narrow ones act as margins.
changes,implementation,current=st.columns([0.01,0.9,0.01])
with implementation:
st.write("<br>"*5,unsafe_allow_html=True)
# A document was uploaded: embed it and build an in-memory retrieval index.
# NOTE(review): the uploader also accepts "xlsx", but everything below
# treats the upload as a PDF — confirm.
if bg_doc:
canvas_result=None
# NOTE(review): the first assignment is overwritten by the next line.
binary_data = bg_doc.getvalue()
binary_data = base64.b64encode(bg_doc.getvalue()).decode('utf-8')
# NOTE(review): the PDF is embedded twice (fixed-size embed here, 100%-size
# embed below) — confirm both are wanted.
pdf_display = F'<embed class="pdfobject" type="application/pdf" title="Embedded PDF" src="data:application/pdf;base64,{binary_data}" width={screen_width//2.07} height={screen_height//1.83} type="application/pdf">'
st.markdown(pdf_display, unsafe_allow_html=True)
pdf_display = f"""<embed
class="pdfobject"
type="application/pdf"
title="Embedded PDF"
src="data:application/pdf;base64,{binary_data}"
style="overflow: auto; width: 100%; height: 100%;">"""
st.markdown(pdf_display, unsafe_allow_html=True)
# Persist the upload as "temp.pdf" so the path-based helpers can read it.
with open("temp.pdf", "wb") as f:
f.write(bg_doc.getbuffer())
# Process the uploaded PDF file
data = process_pdf("temp.pdf")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
chunks = text_splitter.split_documents(data)
# chunk_texts = [str(chunk.page_content) for chunk in chunks]
# print("testing",chunk_texts)
# NOTE(review): the embedding model is constructed and all chunks are
# re-encoded at this point on every script run that has a document
# uploaded; there is no caching here.
model_name = "all-MiniLM-L6-v2"
model = SentenceTransformer(model_name)
embeddings = [model.encode(str(chunk.page_content)) for chunk in chunks]
# vector_store is a plain list of (embedding, chunk text) pairs.
vector_store = []
for chunk, embedding in zip(chunks, embeddings):
vector_store.append((embedding, chunk.page_content) )
# No document: show the drawing canvas over the working image (or over a
# default image file when there is none).
else:
canvas_result = st_canvas(
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
stroke_width=stroke_width,
stroke_color=stroke_color,
background_color=bg_color,
background_image=gen_image if gen_image else Image.open("/home/user/app/ALL_image_formation/image_gen.png"),
update_streamlit=True,
height=int(screen_height//2.16) if screen_height!=1180 else screen_height//2,
width=int(screen_width//2.3) if screen_width!=820 else screen_width//2,
drawing_mode=drawing_mode,
point_display_radius=point_display_radius if drawing_mode == 'point' else 0,
key="canvas",
)
# ---- Chat input and prompt handling (document question or image edit) ----
with column1:
# prompt=st.text_area("Please provide the prompt")
prompt=st.chat_input("Please provide the prompt")
negative_prompt="the black masked area"
# run=st.button("run_experiment")
# Document mode: find the chunk most similar to the latest stored prompt
# (or to the placeholder "something" when there is no history) and render
# the PDF pages with that text outlined.
if bg_doc:
if len(dictionary['every_prompt_with_val'])==0:
query_embedding = model.encode(["something"])
else:
query_embedding = model.encode([dictionary['every_prompt_with_val'][-1][0]])
# max() over (similarity, text) pairs; [-1] keeps the text of the best pair.
retrieved_chunks = max([(util.cos_sim(match[0],query_embedding),match[-1])for match in vector_store])[-1]
with implementation:
text_lookup=retrieved_chunks
pages=[]
with fitz.open("temp.pdf") as doc:
# NOTE(review): page_number is only used by the commented-out line below.
page_number = st.sidebar.number_input(
"Page number", min_value=1, max_value=doc.page_count, value=1, step=1
)
# NOTE(review): `page_no - 1` makes the first loaded index -1 and never
# requests the last index directly; looks like an off-by-one — confirm
# against PyMuPDF's handling of negative page numbers.
for page_no in range(doc.page_count):
pages.append(doc.load_page(page_no - 1))
# areas = pages[page_number-1].search_for(text_lookup)
with st.container(height=int(screen_height//1.8)):
# Pages are iterated in reverse list order; every match of the retrieved
# text gets a rectangle annotation before the page is rendered at 100 dpi.
for pg_no in pages[::-1]:
areas = pg_no.search_for(text_lookup)
for area in areas:
pg_no.add_rect_annot(area)
pix = pg_no.get_pixmap(dpi=100).tobytes()
st.image(pix,use_column_width=True)
# Document question: wrap the prompt with the best-matching chunk as context,
# store it with the "@working" marker and rerun; the CHAT HISTORY tab then
# streams the answer.
if bg_doc and prompt:
query_embedding = model.encode([prompt])
retrieved_chunks = max([(util.cos_sim(match[0],query_embedding),match[-1])for match in vector_store])[-1]
print(retrieved_chunks)
prompt = "Context: "+ retrieved_chunks +"\n"+send_prompt()+ "\n"+prompt
modifiedValue="@working"
dictionary['every_prompt_with_val'].append((prompt,modifiedValue))
st.rerun()
# Canvas mode: the classifier label decides between an image edit and a
# plain chat answer.
elif not bg_doc and canvas_result.image_data is not None:
if prompt:
text_or_image=multimodel_output(prompt)
# "LABEL_0" is the label that routes the prompt to image generation.
if text_or_image=="LABEL_0":
if "generated_image_prompt" not in dictionary:
dictionary['generated_image_prompt']=[]
# Remember the prompt (newest first) for the PROMPT IMPROVEMENT tab.
if prompt not in dictionary['prompt_collection'] and prompt not in dictionary['generated_image_prompt']:
dictionary['prompt_collection']=[prompt]+dictionary['prompt_collection']
# image_data.shape[:2] is (first axis, second axis); the pair is swapped
# before being passed to Image.resize.
new_size=np.array(canvas_result.image_data).shape[:2]
new_size=(new_size[-1],new_size[0])
# Upload changed since the last generation: start from the new upload,
# resized to the canvas, or from a default image when nothing is uploaded.
if bg_image!=dictionary["bg_image"] :
dictionary["bg_image"]=bg_image
if bg_image!=None:
imf=Image.open(bg_image).resize(new_size)
else:
with open("/home/user/app/lotte_animation_saver/animation_4.json") as read:
url_json=json.load(read)
st_lottie(url_json)
imf=Image.open("/home/user/app/ALL_image_formation/home_screen.jpg").resize(new_size)
# Same upload as before: continue from the image at the front of the
# list, or from the default image when nothing was generated yet.
else:
if len(dictionary['current_image'])!=0:
imf=dictionary['current_image'][0]
else:
with open("/home/user/app/lotte_animation_saver/animation_4.json") as read:
url_json=json.load(read)
st_lottie(url_json)
imf=Image.open("/home/user/app/ALL_image_formation/home_screen.jpg")
# Boolean mask from the canvas data: True where any channel value is > 0.
negative_image =d4_to_3d(np.array(canvas_result.image_data))
# An all-False mask (sum == 0) is replaced by an all-True mask.
if np.sum(negative_image)==0:
negative_image=Image.fromarray(np.where(negative_image == False, True, negative_image))
else:
negative_image=Image.fromarray(negative_image)
# Call the remote API, save the result to disk, put it at the front of the
# image list, record it in the history and rerun.
modifiedValue=model_out_put(imf,negative_image,prompt,negative_prompt)
modifiedValue.save("/home/user/app/ALL_image_formation/current_session_image.png")
dictionary['current_image']=[modifiedValue]+dictionary['current_image']
dictionary['every_prompt_with_val'].append((prompt,modifiedValue))
st.rerun()
# Any other label: queue the prompt as a text question with the "@working" marker.
else:
st.write("nothing importent")
modifiedValue="@working"
dictionary['every_prompt_with_val'].append((prompt,modifiedValue))
st.rerun()
# st.image(modifiedValue,width=300)
# if canvas_result.json_data is not None:
# objects = pd.json_normalize(canvas_result.json_data["objects"]) # need to convert obj to str because PyArrow
# for col in objects.select_dtypes(include=['object']).columns:
# objects[col] = objects[col].astype("str")