import requests
import pandas as pd
import time
from datetime import datetime
from dotenv import load_dotenv
import os
import gradio as gr

load_dotenv()

# Note: call_groq below relies on the Groq client reading GROQ_API_KEY from
# the environment; XAI_API_KEY is loaded here but not currently used.
XAI_API_KEY = os.getenv("XAI_API_KEY")

# Global variable to store the most recent analysis results
GLOBAL_ANALYSIS_STORAGE = {
    'subreddit': None,
    'data': None
}
def call_LLM(query):
    return call_groq(query)


def call_groq(query):
    from groq import Groq

    client = Groq()
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": query}
        ],
        model="llama3-8b-8192",
        temperature=0.5,
        max_tokens=1024,
        top_p=1,
        stop=None,
        stream=False,
    )
    return chat_completion.choices[0].message.content
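
# A minimal smoke test for the LLM wrapper (illustrative; assumes a valid
# GROQ_API_KEY is set in .env):
#   reply = call_LLM("Say 'ready' if you can read this.")
#   print(reply)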
def process(row):
    """
    Format the post title and comment into a single prompt so the model sees
    the full post, then ask the LLM whether it describes a business problem.
    """
    prompt = (
        "The below is a reddit post. Take a look and tell me if there is a "
        "business problem to be solved here ||| "
        f"title: {row['post_title']} ||| comment: {row['comment_body']}"
    )
    return call_LLM(prompt)
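
# Illustrative call with a hypothetical row (process only needs the
# 'post_title' and 'comment_body' fields):
#   example_row = {'post_title': 'Expense reports are painful',
#                  'comment_body': 'I lose hours every week to receipts.'}
#   print(process(example_row))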
# Helper functions: extract_comment_data, fetch_top_comments,
# fetch_subreddits, fetch_top_posts
def extract_comment_data(comment, post_info):
    """Extract relevant data from a comment"""
    return {
        'subreddit': post_info['subreddit'],
        'post_title': post_info['title'],
        'post_score': post_info['score'],
        'post_created_utc': post_info['created_utc'],
        'comment_id': comment['data'].get('id'),
        'comment_author': comment['data'].get('author'),
        'comment_body': comment['data'].get('body'),
        'comment_score': comment['data'].get('score', 0),
        'comment_created_utc': datetime.fromtimestamp(comment['data'].get('created_utc', 0)),
        'post_url': post_info['url'],
        # post_info['permalink'] is already a full URL (see fetch_top_posts),
        # so don't prepend the domain a second time
        'comment_url': f"{post_info['permalink']}{comment['data'].get('id')}",
    }
def fetch_top_comments(post_df, num_comments=2):
    """
    Fetch top comments for each post in the dataframe, sorted by upvotes
    """
    all_comments = []
    total_posts = len(post_df)
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }

    print(f"\nFetching top {num_comments} most upvoted comments for {total_posts} posts...")

    for idx, post in post_df.iterrows():
        print(f"\nProcessing post {idx + 1}/{total_posts}")
        print(f"Title: {post['title'][:100]}...")
        print(f"Post Score: {post['score']}, Number of Comments: {post['num_comments']}")
        try:
            json_url = post['permalink'].replace('https://www.reddit.com', '') + '.json'
            url = f'https://www.reddit.com{json_url}'
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            data = response.json()

            if len(data) > 1:
                comments_data = data[1]['data']['children']

                # Filter out non-comment entries and extract scores
                valid_comments = [
                    comment for comment in comments_data
                    if comment['kind'] == 't1' and comment['data'].get('score') is not None
                ]

                # Sort comments by score (upvotes) in descending order
                sorted_comments = sorted(
                    valid_comments,
                    key=lambda x: x['data'].get('score', 0),
                    reverse=True
                )

                # Take only the top N comments
                top_comments = sorted_comments[:num_comments]

                # Print comment scores for verification
                print("\nTop comment scores for this post:")
                for i, comment in enumerate(top_comments, 1):
                    score = comment['data'].get('score', 0)
                    print(f"Comment {i}: {score} upvotes")

                # Add to main list
                for comment in top_comments:
                    all_comments.append(extract_comment_data(comment, post))

            time.sleep(2)
        except requests.exceptions.RequestException as e:
            print(f"Error fetching comments for post {idx + 1}: {e}")
            continue

    # Create DataFrame and sort
    comments_df = pd.DataFrame(all_comments)
    if not comments_df.empty:
        # Verify sorting by showing top comments for each post
        print("\nVerification of comment sorting:")
        for post_title in comments_df['post_title'].unique():
            post_comments = comments_df[comments_df['post_title'] == post_title]
            print(f"\nPost: {post_title[:100]}...")
            print("Comment scores:", post_comments['comment_score'].tolist())
    return comments_df
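
# The returned dataframe has one row per (post, comment) pair, with the
# columns built in extract_comment_data (comment_body, comment_score,
# post_url, comment_url, ...).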
def fetch_subreddits(limit=10, min_subscribers=1000):
    """
    Fetch subreddits from Reddit

    Args:
        limit (int): Number of subreddits to fetch
        min_subscribers (int): Minimum number of subscribers required
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    subreddits_data = []
    after = None

    while len(subreddits_data) < limit:
        try:
            url = 'https://www.reddit.com/subreddits/popular.json?limit=100'
            if after:
                url += f'&after={after}'

            print(f"Fetching subreddits... Current count: {len(subreddits_data)}")
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            data = response.json()

            for subreddit in data['data']['children']:
                subreddit_data = subreddit['data']
                if subreddit_data.get('subscribers', 0) >= min_subscribers:
                    sub_info = {
                        'display_name': subreddit_data.get('display_name'),
                        'display_name_prefixed': subreddit_data.get('display_name_prefixed'),
                        'title': subreddit_data.get('title'),
                        'subscribers': subreddit_data.get('subscribers', 0),
                        'active_users': subreddit_data.get('active_user_count', 0),
                        'created_utc': datetime.fromtimestamp(subreddit_data.get('created_utc', 0)),
                        'description': subreddit_data.get('description'),
                        'subreddit_type': subreddit_data.get('subreddit_type'),
                        'over18': subreddit_data.get('over18', False),
                        'url': f"https://www.reddit.com/r/{subreddit_data.get('display_name')}/"
                    }
                    subreddits_data.append(sub_info)

            after = data['data'].get('after')
            if not after:
                print("Reached end of listings")
                break

            time.sleep(2)
        except requests.exceptions.RequestException as e:
            print(f"Error fetching data: {e}")
            break

    # A single 100-item page can overshoot the requested limit, so truncate
    return pd.DataFrame(subreddits_data[:limit])
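
# Example (illustrative parameters):
#   subs = fetch_subreddits(limit=5, min_subscribers=100000)
#   print(subs[['display_name', 'subscribers']])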
def fetch_top_posts(subreddit, limit=5):
    """
    Fetch top posts from a subreddit using Reddit's JSON API

    Args:
        subreddit (str): Name of the subreddit without the 'r/'
        limit (int): Maximum number of posts to fetch

    Returns:
        pd.DataFrame: One row per post
    """
    posts_data = []
    url = f'https://www.reddit.com/r/{subreddit}/top.json?t=all&limit={limit}'
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }

    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()
        data = response.json()

        for post in data['data']['children']:
            post_data = post['data']
            posts_data.append({
                'subreddit': subreddit,
                'title': post_data.get('title'),
                'score': post_data.get('score'),
                'num_comments': post_data.get('num_comments'),
                'created_utc': datetime.fromtimestamp(post_data.get('created_utc', 0)),
                'url': post_data.get('url'),
                'permalink': 'https://www.reddit.com' + post_data.get('permalink', '')
            })
        time.sleep(2)
    except requests.exceptions.RequestException as e:
        print(f"Error fetching posts from r/{subreddit}: {e}")

    return pd.DataFrame(posts_data)
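
# The two fetchers compose into a small pipeline (subreddit name is illustrative):
#   posts = fetch_top_posts('smallbusiness', limit=3)
#   comments = fetch_top_comments(posts, num_comments=2)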
def show_dataframe(subreddit):
    # Fetch top posts
    top_posts = fetch_top_posts(subreddit)

    # Fetch top comments for these posts
    data_to_analyze = fetch_top_comments(top_posts)

    # Process and analyze each comment
    responses = []
    for idx, row in data_to_analyze.iterrows():
        responses.append(process(row))
        print(f"Row {idx} done")

    # Add analysis to the dataframe
    data_to_analyze['analysis'] = responses

    # Store in global storage for quick access
    GLOBAL_ANALYSIS_STORAGE['subreddit'] = subreddit
    GLOBAL_ANALYSIS_STORAGE['data'] = data_to_analyze

    return data_to_analyze
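
# Example (illustrative subreddit; note this also populates
# GLOBAL_ANALYSIS_STORAGE as a side effect for the UI callbacks):
#   analyzed = show_dataframe('Entrepreneur')
#   print(analyzed[['post_title', 'analysis']].head())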
def launch_interface():
    # Fetch list of subreddits for user to choose from
    sub_reddits = fetch_subreddits()
    subreddit_list = sub_reddits["display_name"].tolist()

    # Create Gradio Blocks for more flexible interface
    with gr.Blocks() as demo:
        # Title and description
        gr.Markdown("# Reddit Business Problem Analyzer")
        gr.Markdown("Discover potential business opportunities from Reddit discussions")

        # Subreddit selection
        subreddit_dropdown = gr.Dropdown(
            choices=subreddit_list,
            label="Select Subreddit",
            info="Choose a subreddit to analyze"
        )

        # Outputs
        with gr.Row():
            with gr.Column():
                # Overall Analysis Section
                gr.Markdown("## Overall Analysis")
                # overall_analysis = gr.Textbox(
                #     label="Aggregated Business Insights",
                #     interactive=False,
                #     lines=5
                # )

                # Results Table
                results_table = gr.Dataframe(
                    label="Analysis Results",
                    headers=["Index", "Post Title", "Comment", "Analysis"],
                    interactive=False
                )

                # Row Selection
                row_index = gr.Number(
                    label="Select Row Index for Detailed View",
                    precision=0
                )

            with gr.Column():
                # Detailed Post Analysis
                gr.Markdown("## Detailed Post Analysis")
                detailed_analysis = gr.Markdown(
                    label="Detailed Insights"
                )
        # Function to update posts when subreddit is selected
        def update_posts(subreddit):
            # Fetch and analyze data
            data_to_analyze = show_dataframe(subreddit)

            # Prepare table data
            table_data = data_to_analyze[['post_title', 'comment_body', 'analysis']].reset_index()
            table_data.columns = ['Index', 'Post Title', 'Comment', 'Analysis']

            return table_data, None

        # Function to show detailed analysis for a specific row
        def show_row_details(row_index):
            # Ensure we have data loaded
            if GLOBAL_ANALYSIS_STORAGE['data'] is None:
                return "Please select a subreddit first."
            try:
                # Convert to integer; the table's Index column matches the
                # dataframe's 0-based index, so no offset is needed
                row_index = int(row_index)

                # Retrieve the specific row
                row_data = GLOBAL_ANALYSIS_STORAGE['data'].loc[row_index]

                # Format detailed view
                detailed_view = f"""
### Post Details
**Title:** {row_data.get('post_title', 'N/A')}

**Comment:** {row_data.get('comment_body', 'N/A')}

**Comment Score:** {row_data.get('comment_score', 'N/A')}

**Analysis:** {row_data.get('analysis', 'No analysis available')}

**Post URL:** {row_data.get('post_url', 'N/A')}

**Comment URL:** {row_data.get('comment_url', 'N/A')}
"""
                return detailed_view
            except (KeyError, ValueError, TypeError) as e:
                return f"Error retrieving row details: {str(e)}"
        # Event Listeners
        subreddit_dropdown.change(
            fn=update_posts,
            inputs=subreddit_dropdown,
            outputs=[results_table, detailed_analysis]
        )

        row_index.change(
            fn=show_row_details,
            inputs=row_index,
            outputs=detailed_analysis
        )

    return demo


# Launch the interface
if __name__ == "__main__":
    interface = launch_interface()
    interface.launch(share=True)