import streamlit as st
import pandas as pd
import torch
from transformers import TapexTokenizer, BartForConditionalGeneration
import xml.etree.ElementTree as ET
from io import StringIO
import logging
from datetime import datetime
import time

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Cache the model and tokenizer across Streamlit reruns, as the docstring promises
@st.cache_resource
def load_model():
    """
    Load and cache the TAPEX model and tokenizer using Streamlit's caching
    """
    try:
        tokenizer = TapexTokenizer.from_pretrained(
            "microsoft/tapex-large-finetuned-wtq",
            model_max_length=1024
        )
        model = BartForConditionalGeneration.from_pretrained(
            "microsoft/tapex-large-finetuned-wtq"
        )
        # Prefer GPU when available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)
        model.eval()
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None


def parse_xml_to_dataframe(xml_string: str):
    """
    Parse XML string to DataFrame with error handling
    """
    try:
        tree = ET.parse(StringIO(xml_string))
        root = tree.getroot()
        data = []
        columns = set()
        # First pass: collect all possible columns
        for record in root.findall('.//record'):
            columns.update(elem.tag for elem in record)
        # Second pass: create data rows
        for record in root.findall('.//record'):
            row_data = {col: None for col in columns}
            for elem in record:
                row_data[elem.tag] = elem.text
            data.append(row_data)
        df = pd.DataFrame(data)
        # Convert numeric columns (automatically detect)
        for col in df.columns:
            try:
                df[col] = pd.to_numeric(df[col])
            except (ValueError, TypeError):
                continue
        return df, None
    except Exception as e:
        return None, f"Error parsing XML: {str(e)}"
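

# Illustrative usage sketch (comments only, not executed): parse_xml_to_dataframe
# looks for <record> elements anywhere in the document and turns each child tag
# into a DataFrame column, so an input shaped like
#   <data><record><company>Apple</company><revenue>365.7</revenue></record></data>
# would yield a one-row DataFrame with 'company' and 'revenue' columns (column
# order is not guaranteed, since tags are collected into a set) and a None error:
#   df, err = parse_xml_to_dataframe(xml_snippet)   # xml_snippet: the string above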


def process_query(tokenizer, model, df, query: str):
    """
    Process a single query using the TAPEX model
    """
    try:
        start_time = time.time()
        # Handle direct DataFrame operations for common queries
        query_lower = query.lower()
        if "highest" in query_lower or "maximum" in query_lower:
            for col in df.select_dtypes(include=['number']).columns:
                if col.lower() in query_lower:
                    return df.loc[df[col].idxmax()].to_dict()
        elif "average" in query_lower or "mean" in query_lower:
            for col in df.select_dtypes(include=['number']).columns:
                if col.lower() in query_lower:
                    return f"Average {col}: {df[col].mean():.2f}"
        elif "total" in query_lower or "sum" in query_lower:
            for col in df.select_dtypes(include=['number']).columns:
                if col.lower() in query_lower:
                    return f"Total {col}: {df[col].sum():.2f}"
        # Use TAPEX for more complex queries
        with torch.no_grad():
            encoding = tokenizer(
                table=df.astype(str),
                query=query,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            # Move input tensors to the same device as the model (CPU or GPU)
            encoding = {k: v.to(model.device) for k, v in encoding.items()}
            outputs = model.generate(**encoding)
            answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        processing_time = time.time() - start_time
        return f"Answer: {answer} (Processing time: {processing_time:.2f}s)"
    except Exception as e:
        return f"Error processing query: {str(e)}"


def main():
    st.title("XML Data Query System")
    st.write("Upload your XML data and ask questions about it!")

    # Initialize session state for XML input and query if not exists
    if 'xml_input' not in st.session_state:
        st.session_state.xml_input = ""
    if 'current_query' not in st.session_state:
        st.session_state.current_query = ""

    # Load model
    with st.spinner("Loading TAPEX model... (this may take a few moments)"):
        tokenizer, model = load_model()
    if tokenizer is None or model is None:
        st.error("Failed to load the model. Please refresh the page.")
        return

    # XML Input
    xml_input = st.text_area(
        "Enter your XML data here:",
        value=st.session_state.xml_input,
        height=200,
        help="Paste your XML data here. Make sure it's properly formatted."
    )

    # Sample XML button
    if st.button("Load Sample XML"):
        st.session_state.xml_input = """<?xml version="1.0" encoding="UTF-8"?>
<data>
    <records>
        <record>
            <company>Apple</company>
            <revenue>365.7</revenue>
            <employees>147000</employees>
            <year>2021</year>
        </record>
        <record>
            <company>Microsoft</company>
            <revenue>168.1</revenue>
            <employees>181000</employees>
            <year>2021</year>
        </record>
        <record>
            <company>Amazon</company>
            <revenue>386.1</revenue>
            <employees>1608000</employees>
            <year>2021</year>
        </record>
    </records>
</data>"""
        st.rerun()

    if xml_input:
        df, error = parse_xml_to_dataframe(xml_input)
        if error:
            st.error(error)
        else:
            st.success("XML parsed successfully!")

            # Display DataFrame
            st.subheader("Parsed Data:")
            st.dataframe(df)

            # Query input
            query = st.text_input(
                "Enter your question about the data:",
                value=st.session_state.current_query,
                help="Example: 'Which company has the highest revenue?'"
            )

            # Process query
            if query:
                with st.spinner("Processing query..."):
                    result = process_query(tokenizer, model, df, query)
                st.write(result)

            # Sample queries
            st.subheader("Sample Questions (Click to use):")
            sample_queries = [
                "Which company has the highest revenue?",
                "What is the average revenue of all companies?",
                "How many employees does Microsoft have?",
                "Which company has the most employees?",
                "What is the total revenue of all companies?"
            ]

            # Create columns for sample query buttons
            cols = st.columns(len(sample_queries))
            for idx, (col, sample_query) in enumerate(zip(cols, sample_queries)):
                with col:
                    if st.button(f"Query {idx + 1}", help=sample_query, key=f"query_btn_{idx}"):
                        st.session_state.current_query = sample_query
                        st.rerun()

            # Display the sample queries as text for reference
            with st.expander("View all sample questions"):
                for idx, query in enumerate(sample_queries, 1):
                    st.write(f"{idx}. {query}")


if __name__ == "__main__":
    main()
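
# To run this app locally (assuming the file is saved as app.py and that
# streamlit, pandas, torch, and transformers are installed):
#   streamlit run app.py
# Note: the microsoft/tapex-large-finetuned-wtq checkpoint is downloaded from
# the Hugging Face Hub on first use, so the initial model load can take a while.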