sandrocalzada's picture
Update app.py
8006c7e
raw
history blame
2.67 kB
import streamlit as st
import face_recognition
import os
import cv2
import insightface
import pickle
from insightface.app import FaceAnalysis
from concurrent.futures import ThreadPoolExecutor
# Streamlit re-executes this whole script on every widget interaction, so the
# heavy models are built inside a cached factory; without the cache they would
# be reloaded on each rerun instead of only once per server process.
@st.cache_resource
def _load_models():
    """Build the face analyser and the face-swap model.

    Returns:
        A ``(analyser, face_swap_model)`` tuple: the prepared ``FaceAnalysis``
        instance and the model loaded from ``inswapper_128.onnx`` in the
        current working directory.
    """
    analyser = FaceAnalysis(name='buffalo_l')
    analyser.prepare(ctx_id=0, det_size=(640, 640))
    # download=False: the ONNX file must already exist next to the app.
    model_path = os.path.join(os.getcwd(), 'inswapper_128.onnx')
    face_swap_model = insightface.model_zoo.get_model(model_path, download=False)
    return analyser, face_swap_model


# Module-level names kept as-is because face_swapper() reads them as globals.
app, swapper = _load_models()
# Read the pickle file a single time; Streamlit's cache serves later calls.
@st.cache_data
def load_data_images():
    """Return the object unpickled from ``data_images.pkl``.

    The file is looked up in the current working directory.
    """
    pickle_path = os.path.join(os.getcwd(), 'data_images.pkl')
    with open(pickle_path, 'rb') as handle:
        contents = pickle.load(handle)
    return contents
# Module-level store read by process(), which looks up its 'encodings' and
# 'content' entries.
data_images = load_data_images()
def face_swapper(image_background, image_customer):
    """Replace every face found in the background with the customer's face.

    Args:
        image_background: Image whose detected faces are replaced.
        image_customer: Image providing the source face; the first face
            returned by the analyser is used.

    Returns:
        The background image after each detected face has been swapped. If no
        face is detected in the background, it is returned unchanged.

    Raises:
        ValueError: If no face is detected in ``image_customer``.
    """
    customer_faces = app.get(image_customer)
    if not customer_faces:
        # Previously this surfaced as a bare "list index out of range".
        raise ValueError('No face was detected in the customer image.')
    face_customer = customer_faces[0]

    # Faces are detected once on the untouched background; each swap then
    # builds on the result of the previous one.
    for face in app.get(image_background):
        image_background = swapper.get(image_background, face, face_customer, paste_back=True)
    return image_background
def process(image):
    """Swap the face from ``image`` onto the closest-matching stored background.

    The first face found in ``image`` is encoded and compared against the
    stored encodings in ``data_images['encodings']``; the entry of
    ``data_images['content']`` with the smallest distance is used as the
    background for ``face_swapper``.

    Args:
        image: Whatever ``face_recognition.load_image_file`` accepts (the
            script passes the uploaded file object).

    Returns:
        The image returned by ``face_swapper``.

    Raises:
        ValueError: If no face is detected in ``image``, or if there are no
            stored backgrounds to compare against.
    """
    background_encodings = data_images['encodings']
    background_contents = data_images['content']

    image_loaded = face_recognition.load_image_file(image)
    encodings = face_recognition.face_encodings(image_loaded)
    if not encodings:
        # Previously this surfaced as a bare "list index out of range".
        raise ValueError('No face was detected in the uploaded image.')
    face_encoding = encodings[0]

    face_distances = face_recognition.face_distance(background_encodings, face_encoding)

    # min() keeps the first of several equally-close candidates, matching the
    # strict ">" comparison of the loop it replaces. Only the distance is used
    # as the key, so the background images themselves are never compared.
    closest = min(
        zip(face_distances, background_contents),
        key=lambda pair: pair[0],
        default=None,
    )
    if closest is None:
        raise ValueError('There are no stored background images to compare against.')

    _, closest_content = closest
    return face_swapper(closest_content, image_loaded)
# Result of the current script run; stays None until "Process" is pressed and
# processing succeeds.
image_output = None

st.title('Change Faces')

option = st.radio('How would you like to upload your image?', ('File', 'WebCam'), horizontal=True)
if option == 'File':
    uploaded_file = st.file_uploader('Choose your image', type=['jpg', 'png', 'jpeg'])
else:
    uploaded_file = st.camera_input("Take a picture")

if uploaded_file is not None:
    # Only file uploads are echoed back to the page.
    if option == 'File':
        st.image(uploaded_file)

    if st.button('Process'):
        # process() is called directly: submitting a single task to a thread
        # pool and immediately waiting on it gave no concurrency, only overhead.
        try:
            image_output = process(uploaded_file)
        except (IndexError, ValueError) as error:
            # Raised when no face can be found in the picture; show a message
            # in the page instead of a traceback.
            st.error(f'Could not process the image (is a face clearly visible?): {error}')
        else:
            st.image(image_output)

    if image_output is not None:
        # Swap the R and B channels before encoding: cv2.imencode expects BGR
        # order. NOTE(review): assumes image_output is RGB, as st.image's
        # default display above implies - confirm.
        image_output_to_download = cv2.cvtColor(image_output, cv2.COLOR_BGR2RGB)
        encoded_ok, image_output_to_download = cv2.imencode('.jpg', image_output_to_download)
        if encoded_ok:
            # The payload is always JPEG, so the download name must end in
            # .jpg even when the upload was a .png.
            download_stem = os.path.splitext(uploaded_file.name)[0]
            st.download_button(
                'Download image',
                image_output_to_download.tobytes(),
                file_name=f'output_{download_stem}.jpg',
                mime='image/jpeg',
            )
        else:
            st.error('The result could not be encoded as a JPEG for download.')