# functions.py — uploaded by rameshmoorthy
# Commit 06350c8 (verified): "Update functions.py"
# File size: 7.6 kB (Hugging Face file-viewer links: raw, history, blame)
import pandas as pd
from bertopic import BERTopic
from huggingface_hub import InferenceClient
from bertopic.vectorizers import ClassTfidfTransformer
from sentence_transformers import SentenceTransformer
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder
from tempfile import NamedTemporaryFile
import matplotlib.pyplot as plt
import plotly.express as px
import subprocess
from wordcloud import WordCloud
def process_file_bm25(file, mode, min_cluster_size, top_n_words, ngram):
    """Cluster the 'products' column of a CSV/Excel file with BM25-weighted BERTopic.

    Args:
        file: The uploaded file. Either an object exposing ``.name`` (what the
            original code required) or a plain path string. The extension of
            that name selects ``pd.read_csv`` or ``pd.read_excel``
            (case-insensitively).
        mode: ``"Automated clustering"`` leaves ``min_topic_size`` at
            BERTopic's default. Any other value passes ``min_cluster_size``.
        min_cluster_size: BERTopic ``min_topic_size`` used in manual mode.
        top_n_words: Number of words kept per topic representation.
        ngram: Upper bound of the n-gram range ``(1, ngram)``.

    Returns:
        Tuple ``(df, excel_path, topics_info, barchart, topics_plot, heatmap,
        hierarchy)``:
        - ``df``: the input dataframe with ``topic_number`` and
          ``topic_number_encoded`` columns added.
        - ``excel_path``: path of an ``.xlsx`` export of ``df``. The file is
          not deleted automatically.
        - ``topics_info``: BERTopic's ``get_topic_info()`` table.
        - ``barchart``, ``topics_plot``, ``heatmap``, ``hierarchy``: the four
          visualizations. ``barchart`` and ``topics_plot`` become an error
          string when BERTopic cannot render them.

    Raises:
        ValueError: Unsupported file extension, or no 'products' column.
    """
    # Accept both upload objects with a .name attribute and plain path strings.
    file_name = str(getattr(file, "name", file)).lower()
    if file_name.endswith('.csv'):
        df = pd.read_csv(file)
    elif file_name.endswith(('.xls', '.xlsx')):
        df = pd.read_excel(file)
    else:
        raise ValueError("Unsupported file format. Please provide a CSV or Excel file.")

    # The presence check is case-insensitive, so the lookup must be too.
    # Previously a "Products" column passed the check and then hit a KeyError.
    # An exact 'products' column still wins when several spellings exist.
    if 'products' in df.columns:
        product_column = 'products'
    else:
        product_column = next(
            (col for col in df.columns if str(col).lower() == 'products'), None
        )
    if product_column is None:
        raise ValueError("The input file must have a column named 'products'.")
    sentences_list = df[product_column].tolist()
    print(len(sentences_list))

    # BM25 weighting plus down-weighting of very frequent words in c-TF-IDF.
    ctfidf_model = ClassTfidfTransformer(bm25_weighting=True, reduce_frequent_words=True)
    model_kwargs = dict(
        ctfidf_model=ctfidf_model,
        n_gram_range=(1, ngram),
        top_n_words=top_n_words,
    )
    if mode != "Automated clustering":
        model_kwargs["min_topic_size"] = min_cluster_size
    topic_model = BERTopic(**model_kwargs)

    topics, _ = topic_model.fit_transform(sentences_list)
    topics_info = topic_model.get_topic_info()

    # These two visualizations are best-effort: on failure, an error string
    # is returned in place of the figure.
    try:
        barchart = topic_model.visualize_barchart(top_n_topics=10)
    except Exception:
        barchart = 'Error message'
    try:
        topics_plot = topic_model.visualize_topics()
    except Exception:
        topics_plot = ' Error message'
    heatmap = topic_model.visualize_heatmap()
    hierarchy = topic_model.visualize_hierarchy()

    df['topic_number'] = topics
    # Encode the topic numbers to make them categorical.
    label_encoder = LabelEncoder()
    df['topic_number_encoded'] = label_encoder.fit_transform(df['topic_number'])

    # Reserve a unique path and close the handle before pandas reopens the file
    # by name. The original leaked the open handle, and on Windows a
    # NamedTemporaryFile cannot be opened a second time while it is still open.
    with NamedTemporaryFile(delete=False, suffix=".xlsx") as temp_file:
        excel_path = temp_file.name
    df.to_excel(excel_path, index=False)
    return df, excel_path, topics_info, barchart, topics_plot, heatmap, hierarchy
def process_file_bert(file, mode, min_cluster_size, top_n_words, ngram):
    """Cluster the 'products' column of a CSV/Excel file with KeyBERT-refined BERTopic.

    Args:
        file: The uploaded file. Either an object exposing ``.name`` (what the
            original code required) or a plain path string. The extension of
            that name selects ``pd.read_csv`` or ``pd.read_excel``
            (case-insensitively).
        mode: ``"Automated clustering"`` leaves ``min_topic_size`` at
            BERTopic's default. Any other value passes ``min_cluster_size``.
        min_cluster_size: BERTopic ``min_topic_size`` used in manual mode.
        top_n_words: Number of words kept per topic representation.
        ngram: Upper bound of the n-gram range ``(1, ngram)``.

    Returns:
        Tuple ``(df, topics_info, barchart, topics_plot, heatmap, hierarchy)``:
        - ``df``: the input dataframe with ``topic_number`` and
          ``topic_number_encoded`` columns added.
        - ``topics_info``: BERTopic's ``get_topic_info()`` table.
        - ``barchart``, ``topics_plot``, ``heatmap``, ``hierarchy``: the four
          visualizations. ``barchart`` and ``topics_plot`` become an error
          string when BERTopic cannot render them.

    Raises:
        ValueError: Unsupported file extension, or no 'products' column.
    """
    # Accept both upload objects with a .name attribute and plain path strings.
    file_name = str(getattr(file, "name", file)).lower()
    if file_name.endswith('.csv'):
        df = pd.read_csv(file)
    elif file_name.endswith(('.xls', '.xlsx')):
        df = pd.read_excel(file)
    else:
        raise ValueError("Unsupported file format. Please provide a CSV or Excel file.")

    # The presence check is case-insensitive, so the lookup must be too.
    # Previously a "Products" column passed the check and then hit a KeyError.
    # An exact 'products' column still wins when several spellings exist.
    if 'products' in df.columns:
        product_column = 'products'
    else:
        product_column = next(
            (col for col in df.columns if str(col).lower() == 'products'), None
        )
    if product_column is None:
        raise ValueError("The input file must have a column named 'products'.")
    sentences_list = df[product_column].tolist()
    print(len(sentences_list))

    # KeyBERTInspired was used without ever being imported, so every call
    # raised NameError. Import it here, next to its only use.
    from bertopic.representation import KeyBERTInspired

    # Fine-tune the topic representations with KeyBERT-style keyword extraction.
    representation_model = KeyBERTInspired()
    model_kwargs = dict(
        representation_model=representation_model,
        n_gram_range=(1, ngram),
        top_n_words=top_n_words,
    )
    if mode != "Automated clustering":
        model_kwargs["min_topic_size"] = min_cluster_size
    topic_model = BERTopic(**model_kwargs)

    topics, _ = topic_model.fit_transform(sentences_list)
    topics_info = topic_model.get_topic_info()

    # The original wrote results to a module-level `state` object, but nothing
    # in this module defines `state`, so every call died with NameError.
    # Publish only when such an object exists. NOTE(review): presumably the
    # app injects `state`; confirm.
    shared_state = globals().get("state")
    if shared_state is not None:
        shared_state.df_topics_bert = topics_info

    # These two visualizations are best-effort: on failure, an error string
    # is returned in place of the figure.
    try:
        barchart = topic_model.visualize_barchart(top_n_topics=10)
    except Exception:
        barchart = 'Error message'
    try:
        topics_plot = topic_model.visualize_topics()
    except Exception:
        topics_plot = ' Error message'
    heatmap = topic_model.visualize_heatmap()
    hierarchy = topic_model.visualize_hierarchy()

    df['topic_number'] = topics
    # Encode the topic numbers to make them categorical.
    label_encoder = LabelEncoder()
    df['topic_number_encoded'] = label_encoder.fit_transform(df['topic_number'])

    # The original also exported df to a NamedTemporaryFile and discarded the
    # path, which leaked an open handle and left an unreachable file per call.
    # That write is dropped because nothing could ever read it.
    if shared_state is not None:
        shared_state.df_bert = df
    return df, topics_info, barchart, topics_plot, heatmap, hierarchy
# Module-wide Hugging Face inference client bound to the
# Mixtral-8x7B-Instruct model; generate() streams completions through it.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
def format_prompt(message, history):
    """Render a chat transcript in the Mixtral ``[INST]`` prompt format.

    Args:
        message: The new user message to append as the final instruction.
        history: Iterable of ``(user_turn, bot_turn)`` pairs from earlier turns.

    Returns:
        A string that starts with ``<s>``. Each past exchange appears as
        ``[INST] user [/INST] bot</s> ``, followed by ``[INST] message [/INST]``.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=4096, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a chat completion from the module-level ``client``.

    Args:
        prompt: The latest user message.
        history: ``(user, bot)`` pairs forwarded to ``format_prompt``.
        system_prompt: Instructions prepended to ``prompt`` as
            ``"{system_prompt}, {prompt}"``. An empty or ``None`` value is now
            skipped instead of producing a stray leading ", " or "None, ".
        temperature: Sampling temperature, clamped to at least 0.01.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Repetition penalty passed through unchanged.

    Yields:
        The accumulated response text after each streamed token.

    Returns:
        The final text. It is only observable via ``StopIteration.value`` or
        ``yield from``.
    """
    temperature = max(float(temperature), 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed: identical inputs reproduce identical samples.
        seed=42,
    )
    user_message = f"{system_prompt}, {prompt}" if system_prompt else prompt
    formatted_prompt = format_prompt(user_message, history)
    # details=True makes each streamed item expose `.token.text`.
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    return output
def generate_plot(topic, x_axis_index, y_axis_index, chart_type, agg_func, df=None):
    """Plot two columns of the rows that belong to one topic.

    Args:
        topic: Topic number whose rows are plotted.
        x_axis_index: Position of the x column within ``df.columns[1:]``. The
            first column is never selectable.
        y_axis_index: Position of the y column within ``df.columns[1:]``.
        chart_type: One of "scatter", "bar", "line", "box", "wordcloud", "pie".
        agg_func: Only consulted for bar charts. "count_distinct" switches to
            ``barmode='group'``; no aggregation is computed here.
        df: Dataframe to plot. Defaults to a module-level ``df`` global, which is
            what the original code read; nothing in this module defines it.

    Returns:
        A plotly figure. For "wordcloud", returns None after drawing with
        matplotlib and calling ``plt.show()``.

    Raises:
        ValueError: ``chart_type`` is not one of the supported values.
        NameError: No ``df`` was passed and no module-level ``df`` exists.
    """
    if df is None:
        try:
            df = globals()["df"]
        except KeyError:
            raise NameError(
                "generate_plot needs a dataframe: pass df=... or define a module-level 'df'"
            ) from None

    selectable_columns = df.columns[1:]
    x_axis = selectable_columns[x_axis_index]
    y_axis = selectable_columns[y_axis_index]
    print(x_axis, y_axis)

    # The original filtered on 'Topic Number'. The process_file_* functions in
    # this module name the column 'topic_number', so fall back to it.
    topic_column = 'Topic Number' if 'Topic Number' in df.columns else 'topic_number'
    filtered_df = df[df[topic_column] == topic]

    if chart_type == "scatter":
        fig = px.scatter(filtered_df, x=x_axis, y=y_axis)
    elif chart_type == "bar":
        print('Bar chart selected')
        if agg_func == "count_distinct":
            fig = px.bar(filtered_df, x=x_axis, y=y_axis, color=y_axis, barmode='group')
        else:
            fig = px.bar(filtered_df, x=x_axis, y=y_axis, color=y_axis)
    elif chart_type == "line":
        fig = px.line(filtered_df, x=x_axis, y=y_axis)
    elif chart_type == "box":
        fig = px.box(filtered_df, x=x_axis, y=y_axis)
    elif chart_type == "wordcloud":
        text = ' '.join(filtered_df[y_axis].astype(str))
        wordcloud = WordCloud(width=800, height=400, random_state=21, max_font_size=110).generate(text)
        plt.figure(figsize=(10, 7))
        plt.imshow(wordcloud, interpolation="bilinear")
        plt.axis('off')
        plt.show()
        return None
    elif chart_type == "pie":
        fig = px.pie(filtered_df, names=x_axis, values=y_axis)
        print('Pie chart selected')
    else:
        # Previously an unknown chart type fell through to `return fig` and
        # raised UnboundLocalError.
        raise ValueError(f"Unsupported chart type: {chart_type!r}")
    return fig