rameshmoorthy's picture
Update functions.py
169a56e verified
raw
history blame
6.97 kB
import pandas as pd
from bertopic import BERTopic
from huggingface_hub import InferenceClient
from bertopic.vectorizers import ClassTfidfTransformer
from sentence_transformers import SentenceTransformer
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder
from tempfile import NamedTemporaryFile
import matplotlib.pyplot as plt
import plotly.express as px
from wordcloud import WordCloud
def process_file_bm25(file, mode, min_cluster_size, top_n_words, ngram):
    """Topic-model the 'products' column of a CSV/Excel file with BERTopic (BM25 c-TF-IDF).

    The c-TF-IDF step uses ``ClassTfidfTransformer(bm25_weighting=True,
    reduce_frequent_words=True)``.

    Args:
        file: Uploaded file object (its ``name`` attribute picks the reader) or a
            plain path string. The object itself is handed to pandas.
        mode: ``"Automated clustering"`` leaves ``min_topic_size`` at BERTopic's
            default; any other value passes ``min_cluster_size`` as ``min_topic_size``.
        min_cluster_size: Minimum topic size used when ``mode`` is not automated.
        top_n_words: Number of words kept per topic representation.
        ngram: Upper bound of the n-gram range ``(1, ngram)``.

    Returns:
        ``(df, xlsx_path, topics_info, barchart, topics_plot, heatmap, hierarchy)``:
        the input frame with added ``topic_number`` / ``topic_number_encoded``
        columns, the path of an .xlsx copy of it, ``get_topic_info()``, and the
        BERTopic figures. ``barchart`` / ``topics_plot`` are replaced by an error
        string when BERTopic cannot draw them.

    Raises:
        ValueError: If the extension is not .csv/.xls/.xlsx or there is no
            'products' column (matched case-insensitively).
    """
    # Accept both file-like uploads (with .name) and plain path strings; compare
    # the extension case-insensitively so "DATA.CSV" is accepted too.
    file_name = str(getattr(file, "name", file)).lower()
    if file_name.endswith('.csv'):
        df = pd.read_csv(file)
    elif file_name.endswith(('.xls', '.xlsx')):
        df = pd.read_excel(file)
    else:
        raise ValueError("Unsupported file format. Please provide a CSV or Excel file.")

    # Resolve the real spelling of the column. Checking a lower-cased header list
    # and then indexing the literal 'products' raised KeyError for 'Products'.
    if 'products' in df.columns:
        products_column = 'products'
    else:
        products_column = next(
            (column for column in df.columns if str(column).lower() == 'products'), None
        )
    if products_column is None:
        raise ValueError("The input file must have a column named 'products'.")

    # The embedding model only accepts strings; blank cells (NaN) or numbers
    # would otherwise crash fit_transform.
    sentences_list = df[products_column].fillna('').astype(str).tolist()

    ctfidf_model = ClassTfidfTransformer(bm25_weighting=True, reduce_frequent_words=True)
    # UI sliders may deliver floats; the vectorizer and clusterer need ints.
    topic_kwargs = dict(
        ctfidf_model=ctfidf_model,
        n_gram_range=(1, int(ngram)),
        top_n_words=int(top_n_words),
    )
    if mode != "Automated clustering":
        topic_kwargs["min_topic_size"] = int(min_cluster_size)
    topic_model = BERTopic(**topic_kwargs)

    topics, _probabilities = topic_model.fit_transform(sentences_list)
    topics_info = topic_model.get_topic_info()

    # These two visualisations can fail (e.g. with very few topics). Keep
    # going with a placeholder, but do not swallow KeyboardInterrupt/SystemExit.
    try:
        barchart = topic_model.visualize_barchart(top_n_topics=10)
    except Exception:
        barchart = 'Error message'
    try:
        topics_plot = topic_model.visualize_topics()
    except Exception:
        topics_plot = ' Error message'
    heatmap = topic_model.visualize_heatmap()
    hierarchy = topic_model.visualize_hierarchy()

    df['topic_number'] = topics
    # Map raw topic ids (which include -1 for outliers) onto 0..k-1 codes.
    label_encoder = LabelEncoder()
    df['topic_number_encoded'] = label_encoder.fit_transform(df['topic_number'])

    # Only reserve the name here and close the handle before pandas reopens the
    # path. Keeping it open leaked a descriptor and breaks on Windows.
    with NamedTemporaryFile(delete=False, suffix=".xlsx") as temp_file:
        output_path = temp_file.name
    df.to_excel(output_path, index=False)
    return df, output_path, topics_info, barchart, topics_plot, heatmap, hierarchy
def process_file_bert(file, mode, min_cluster_size, top_n_words, ngram, state=None):
    """Topic-model the 'products' column of a CSV/Excel file with BERTopic + KeyBERTInspired.

    Args:
        file: Uploaded file object (its ``name`` attribute picks the reader) or a
            plain path string. The object itself is handed to pandas.
        mode: ``"Automated clustering"`` leaves ``min_topic_size`` at BERTopic's
            default; any other value passes ``min_cluster_size`` as ``min_topic_size``.
        min_cluster_size: Minimum topic size used when ``mode`` is not automated.
        top_n_words: Number of words kept per topic representation.
        ngram: Upper bound of the n-gram range ``(1, ngram)``.
        state: Optional object that receives ``df_topics_bert`` (topic info) and
            ``df_bert`` (the labelled frame) as attributes. When omitted, a
            module-level ``state`` global is used if one exists; otherwise
            nothing is stored. Previously this always read the global, which
            is not defined in this module, so the call raised NameError.

    Returns:
        ``(df, topics_info, barchart, topics_plot, heatmap, hierarchy)``.
        ``barchart`` / ``topics_plot`` are replaced by an error string when
        BERTopic cannot draw them.

    Raises:
        ValueError: If the extension is not .csv/.xls/.xlsx or there is no
            'products' column (matched case-insensitively).
    """
    # Accept both file-like uploads (with .name) and plain path strings; compare
    # the extension case-insensitively so "DATA.CSV" is accepted too.
    file_name = str(getattr(file, "name", file)).lower()
    if file_name.endswith('.csv'):
        df = pd.read_csv(file)
    elif file_name.endswith(('.xls', '.xlsx')):
        df = pd.read_excel(file)
    else:
        raise ValueError("Unsupported file format. Please provide a CSV or Excel file.")

    # Resolve the real spelling of the column. Checking a lower-cased header list
    # and then indexing the literal 'products' raised KeyError for 'Products'.
    if 'products' in df.columns:
        products_column = 'products'
    else:
        products_column = next(
            (column for column in df.columns if str(column).lower() == 'products'), None
        )
    if products_column is None:
        raise ValueError("The input file must have a column named 'products'.")

    # The embedding model only accepts strings; blank cells (NaN) or numbers
    # would otherwise crash fit_transform.
    sentences_list = df[products_column].fillna('').astype(str).tolist()

    # Missing from the module's imports; without it this function raised NameError.
    from bertopic.representation import KeyBERTInspired

    # Fine-tune the topic representations with KeyBERT-style keyword extraction.
    representation_model = KeyBERTInspired()
    # UI sliders may deliver floats; the vectorizer and clusterer need ints.
    topic_kwargs = dict(
        representation_model=representation_model,
        n_gram_range=(1, int(ngram)),
        top_n_words=int(top_n_words),
    )
    if mode != "Automated clustering":
        topic_kwargs["min_topic_size"] = int(min_cluster_size)
    topic_model = BERTopic(**topic_kwargs)

    topics, _probabilities = topic_model.fit_transform(sentences_list)
    topics_info = topic_model.get_topic_info()

    if state is None:
        # Backward compatibility: honour a module-level `state` if one was injected.
        state = globals().get("state")
    if state is not None:
        state.df_topics_bert = topics_info

    # These two visualisations can fail (e.g. with very few topics). Keep
    # going with a placeholder, but do not swallow KeyboardInterrupt/SystemExit.
    try:
        barchart = topic_model.visualize_barchart(top_n_topics=10)
    except Exception:
        barchart = 'Error message'
    try:
        topics_plot = topic_model.visualize_topics()
    except Exception:
        topics_plot = ' Error message'
    heatmap = topic_model.visualize_heatmap()
    hierarchy = topic_model.visualize_hierarchy()

    df['topic_number'] = topics
    # Map raw topic ids (which include -1 for outliers) onto 0..k-1 codes.
    label_encoder = LabelEncoder()
    df['topic_number_encoded'] = label_encoder.fit_transform(df['topic_number'])

    # NOTE(review): unlike process_file_bm25, this export path is never returned,
    # so callers cannot reach the file. Confirm whether it is still needed.
    # Close the handle before pandas reopens the path (fd leak / Windows lock).
    with NamedTemporaryFile(delete=False, suffix=".xlsx") as temp_file:
        output_path = temp_file.name
    df.to_excel(output_path, index=False)

    if state is not None:
        state.df_bert = df
    return df, topics_info, barchart, topics_plot, heatmap, hierarchy
# Module-wide Hugging Face inference client; `generate` streams its completions
# through this Mixtral-8x7B instruct endpoint.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
def format_prompt(message, history):
    """Render a chat transcript in the Mixtral ``[INST]`` instruction format.

    The result starts with ``<s>``. Each ``(user_prompt, bot_response)`` pair in
    ``history`` becomes ``"[INST] user [/INST] response</s> "``. The new
    ``message`` is appended last as an unanswered ``"[INST] message [/INST]"`` turn.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=4096, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a completion from the module-level ``client``, yielding the text so far.

    Args:
        prompt: Latest user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, rendered by
            ``format_prompt``.
        system_prompt: Instruction text joined in front of ``prompt`` with ", ".
            When empty/None, ``prompt`` is sent on its own.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Maximum number of tokens to generate (coerced to int).
        top_p: Nucleus-sampling mass (coerced to float).
        repetition_penalty: Repetition penalty (coerced to float).

    Yields:
        The accumulated response text after each streamed token.
    """
    # UI widgets may hand over floats or numeric strings. Normalise every
    # sampling parameter, not just temperature/top_p as before.
    temperature = max(float(temperature), 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,
    )
    # Joining unconditionally sent a stray leading ", " (or "None, ") to the
    # model when no system prompt was supplied.
    message = f"{system_prompt}, {prompt}" if system_prompt else prompt
    formatted_prompt = format_prompt(message, history)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    return output
def generate_plot(topic, x_axis_index, y_axis_index, chart_type, agg_func, df=None, topic_column=None):
    """Build a chart of two columns restricted to the rows of one topic.

    Args:
        topic: Topic value whose rows are plotted.
        x_axis_index: Position of the x column within ``df.columns[1:]``
            (the first column is skipped).
        y_axis_index: Position of the y column within ``df.columns[1:]``.
        chart_type: One of "scatter", "bar", "line", "box", "wordcloud", "pie".
        agg_func: For bar charts, "count_distinct" draws grouped bars; any other
            value keeps Plotly's default bar mode.
        df: DataFrame to plot. Defaults to a module-level ``df`` global, which
            is the only source the original implementation read.
        topic_column: Column holding topic ids. Defaults to 'Topic Number' when
            present, otherwise 'topic_number' (the column written by
            process_file_bm25 / process_file_bert).

    Returns:
        A Plotly figure, or None for "wordcloud" (drawn with matplotlib instead).

    Raises:
        ValueError: If ``chart_type`` is unsupported or no dataframe is available.
        KeyError: If the topic column is missing.
    """
    supported_charts = ("scatter", "bar", "line", "box", "wordcloud", "pie")
    # Reject unknown types up front; the old if/elif chain fell through and
    # failed later with UnboundLocalError on `fig`.
    if chart_type not in supported_charts:
        raise ValueError(f"Unsupported chart type {chart_type!r}; expected one of {supported_charts}.")

    if df is None:
        # Backward compatibility: the original read an (undefined here) module global.
        df = globals().get("df")
    if df is None:
        raise ValueError("No dataframe to plot: pass df=... or set the module-level 'df'.")

    if topic_column is None:
        topic_column = 'Topic Number' if 'Topic Number' in df.columns else 'topic_number'

    value_columns = df.columns[1:]
    x_axis = value_columns[x_axis_index]
    y_axis = value_columns[y_axis_index]
    filtered_df = df[df[topic_column] == topic]

    if chart_type == "scatter":
        return px.scatter(filtered_df, x=x_axis, y=y_axis)
    if chart_type == "bar":
        if agg_func == "count_distinct":
            return px.bar(filtered_df, x=x_axis, y=y_axis, color=y_axis, barmode='group')
        return px.bar(filtered_df, x=x_axis, y=y_axis, color=y_axis)
    if chart_type == "line":
        return px.line(filtered_df, x=x_axis, y=y_axis)
    if chart_type == "box":
        return px.box(filtered_df, x=x_axis, y=y_axis)
    if chart_type == "pie":
        return px.pie(filtered_df, names=x_axis, values=y_axis)

    # Remaining case: "wordcloud" over the y column's text.
    text = ' '.join(filtered_df[y_axis].astype(str))
    wordcloud = WordCloud(width=800, height=400, random_state=21, max_font_size=110).generate(text)
    figure = plt.figure(figsize=(10, 7))
    plt.imshow(wordcloud, interpolation="bilinear")
    plt.axis('off')
    plt.show()
    # With a non-interactive backend show() returns immediately, so each call
    # used to leave another open figure behind. Release it explicitly.
    plt.close(figure)
    return None