LexAIcon / app.py
manuelcozar55's picture
Update app.py
30ed7b0 verified
raw
history blame
No virus
5.95 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from huggingface_hub import login
from PyPDF2 import PdfReader
from docx import Document
import csv
import json
import os
import torch
# Authenticate with the Hugging Face Hub, but only when a token is configured
# in the environment; otherwise keep running anonymously.
huggingface_token = os.environ.get('HUGGINGFACE_TOKEN')
if huggingface_token:
    login(token=huggingface_token)
# Chat model configuration (Mistral-7B-Instruct served through a Hugging Face endpoint).
@st.cache_resource
def load_llm():
    """Create the Mistral chat client and load the matching tokenizer.

    Decorated with ``st.cache_resource`` so the objects are reused across
    Streamlit reruns instead of being rebuilt each time.

    Returns:
        (chat_model, tokenizer): a ``ChatHuggingFace`` wrapper around the
        text-generation endpoint, and the tokenizer from the same repo.
    """
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    endpoint = HuggingFaceEndpoint(repo_id=model_id, task="text-generation")
    chat_model = ChatHuggingFace(llm=endpoint)
    mistral_tokenizer = AutoTokenizer.from_pretrained(model_id)
    return chat_model, mistral_tokenizer


llm_engine_hf, tokenizer = load_llm()
# Classification model configuration (Spanish legal-domain Longformer).
@st.cache_resource
def load_classification_model():
    """Load the legal Longformer sequence classifier and its tokenizer.

    Cached with ``st.cache_resource`` so the model is loaded once and reused
    across Streamlit reruns.

    Returns:
        (model, tokenizer) — model first, tokenizer second.

    NOTE(review): if this checkpoint has no fine-tuned sequence-classification
    head, transformers initializes a fresh, untrained one (num_labels defaults
    to 2). Predictions would then not be meaningful, and only ids 0-1 could be
    produced. Confirm the checkpoint before trusting the predicted labels.
    """
    checkpoint = "mrm8488/legal-longformer-base-8192-spanish"
    legal_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    legal_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return legal_model, legal_tokenizer


classification_model, classification_tokenizer = load_classification_model()
# Maps each classifier output index to its document category name.
id2label = dict(enumerate(["multas", "politicas_de_privacidad", "contratos", "denuncias", "otros"]))
def classify_text(text):
    """Classify a document into a legal category and prepend the label to it.

    Args:
        text: Document text to classify. Only the first 4096 tokens are used
            (longer input is truncated).

    Returns:
        A string of the form "Clasificación: <label>", a blank line,
        "Documento:", then the original text.
    """
    # No padding to max_length: a single sequence needs no padding, and padding
    # every input to 4096 tokens only wastes compute on short documents.
    inputs = classification_tokenizer(text, return_tensors="pt", max_length=4096, truncation=True)
    classification_model.eval()
    with torch.no_grad():
        logits = classification_model(**inputs).logits
    predicted_class_id = logits.argmax(dim=-1).item()
    # Fall back to the raw id if the model head has outputs id2label does not name.
    predicted_label = id2label.get(predicted_class_id, str(predicted_class_id))
    return f"Clasificación: {predicted_label}\n\nDocumento:\n{text}"
def translate(text, target_language):
    """Ask the Mistral chat model to translate a document.

    Args:
        text: Document text to translate.
        target_language: Name of the target language, inserted into the prompt.

    Returns:
        The text content of the model's reply.
    """
    template = '''
Por favor, traduzca el siguiente documento al {LANGUAGE}:
<document>
{TEXT}
</document>
Asegúrese de que la traducción sea precisa y conserve el significado original del documento.
'''
    # Single-pass substitution: braces inside the document text are not
    # reinterpreted, unlike chained .replace() calls, which rewrote any literal
    # "{LANGUAGE}" that appeared in the document.
    formatted_prompt = template.format(LANGUAGE=target_language, TEXT=text)
    outputs = llm_engine_hf.invoke(formatted_prompt)
    return outputs.content
def summarize(text, length):
    """Ask the Mistral chat model to summarize a document.

    Args:
        text: Document text to summarize.
        length: Spanish phrase describing the desired length, inserted verbatim
            after "resumen" (e.g. "de aproximadamente 100 palabras").

    Returns:
        The text content of the model's reply.
    """
    template = '''
Por favor, haga un resumen {LENGTH} del siguiente documento:
<document>
{TEXT}
</document>
Asegúrese de que el resumen sea conciso y conserve el significado original del documento.
'''
    # Same single-pass substitution as translate(); the previously computed
    # tokenizer output was never used, so it is no longer produced.
    prompt = template.format(LENGTH=length, TEXT=text)
    outputs = llm_engine_hf.invoke(prompt)
    return outputs.content
def handle_uploaded_file(uploaded_file):
    """Extract plain text from an uploaded .txt, .pdf, .docx, .csv or .json file.

    Args:
        uploaded_file: File-like upload object exposing ``name`` and ``read()``
            (e.g. a Streamlit ``UploadedFile``).

    Returns:
        The extracted text. For an unsupported extension, returns
        "Tipo de archivo no soportado.". If extraction fails, returns an
        error message string instead of raising.
    """
    try:
        # Match extensions case-insensitively so "CONTRATO.PDF" works like "contrato.pdf".
        file_name = uploaded_file.name.lower()
        if file_name.endswith(".txt"):
            text = uploaded_file.read().decode("utf-8")
        elif file_name.endswith(".pdf"):
            reader = PdfReader(uploaded_file)
            # Guard against extract_text() yielding None (possible in some PyPDF2
            # versions). Newline-separate pages so words at page breaks don't merge.
            text = "\n".join(page.extract_text() or "" for page in reader.pages)
        elif file_name.endswith(".docx"):
            doc = Document(uploaded_file)
            text = "\n".join(para.text for para in doc.paragraphs)
        elif file_name.endswith(".csv"):
            rows = csv.reader(uploaded_file.read().decode("utf-8").splitlines())
            text = " ".join(" ".join(row) for row in rows)
        elif file_name.endswith(".json"):
            data = json.load(uploaded_file)
            # ensure_ascii=False keeps accented Spanish characters readable
            # instead of emitting \uXXXX escapes.
            text = json.dumps(data, indent=4, ensure_ascii=False)
        else:
            text = "Tipo de archivo no soportado."
    except Exception as e:
        # Best-effort extraction: report the failure as text rather than
        # crashing the app, but make it recognizable as an error.
        return f"Error al procesar el archivo {uploaded_file.name}: {e}"
    return text
def generate_response(user_message):
    """Send a single chat message to the Mistral chat model and return its reply text.

    Only the current message is sent; earlier turns are not included.
    """
    return llm_engine_hf.invoke(user_message).content


st.title("LexAIcon")
st.write("Puedes conversar con este chatbot basado en Mistral7B-Instruct y subir archivos para que el chatbot los procese.")
# Initialize the chat-history lists in session_state so they survive reruns.
if "generated" not in st.session_state:
    st.session_state["generated"] = []
if "past" not in st.session_state:
    st.session_state["past"] = []
# User chat input
user_input = st.text_input("Tú: ", "")
# Target language for translation (the selected name is inserted into the prompt)
target_language = st.selectbox("Selecciona el idioma de traducción", ["español", "inglés", "francés", "alemán"])
# Desired summary length
summary_length = st.selectbox("Selecciona la longitud del resumen", ["corto", "medio", "largo"])
# File uploads
uploaded_files = st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"], accept_multiple_files=True)
if st.button("Enviar"):
    if user_input:
        response = generate_response(user_input)
        st.session_state.generated.append({"user": user_input, "bot": response})
# Operation selector: "Resumir" summarizes, "Traducir" translates, and
# "Explicar" runs the document classifier (classify_text).
operation = st.radio("Selecciona una operación", ["Resumir", "Traducir", "Explicar"])
# Length phrase inserted into the summary prompt for each length option.
summary_length_prompts = {
    "corto": "de aproximadamente 50 palabras",
    "medio": "de aproximadamente 100 palabras",
    "largo": "de aproximadamente 500 palabras",
}
if st.button("Ejecutar"):
    if not uploaded_files:
        # Tell the user why nothing happened instead of silently doing nothing.
        st.warning("Sube al menos un archivo antes de ejecutar una operación.")
    else:
        for uploaded_file in uploaded_files:
            file_content = handle_uploaded_file(uploaded_file)
            if operation == "Resumir":
                result = summarize(file_content, summary_length_prompts[summary_length])
            elif operation == "Traducir":
                result = translate(file_content, target_language)
            elif operation == "Explicar":
                result = classify_text(file_content)
            # Label each result with its file so multiple uploads stay distinguishable.
            st.subheader(uploaded_file.name)
            st.write(result)
# Render the stored chat exchanges in the order they were added.
for chat in st.session_state.get("generated") or []:
    st.write(f"Tú: {chat['user']}")
    st.write(f"Chatbot: {chat['bot']}")