# groundbi-factory / discord_utils.py
# naonauno's picture
# Rename discord_utils (47).py to discord_utils.py
# 473d58f verified
import os
import discord
from discord import app_commands
from discord.ext import commands
import torch
from PIL import Image
import io
import asyncio
import aiohttp
import gc
from concurrent.futures import ThreadPoolExecutor
import time
import random
class GenerationQueue:
    """FIFO queue of pending generation requests with per-user position messages."""

    def __init__(self):
        # Pending (interaction, params) tuples, consumed by process_queue().
        self.queue = asyncio.Queue()
        # interaction.id -> params of the request currently being processed.
        self.active_generations = {}
        # user id -> the "Queue position" message that was sent to that user.
        self.position_messages = {}

    def get_position(self, user_id):
        """Return the 1-based position of the first queued request by user_id.

        Returns 0 when the user has no request waiting in the queue.
        """
        # asyncio.Queue offers no public way to iterate pending items, so this
        # peeks at its private `_queue` container.
        for position, (interaction, _) in enumerate(self.queue._queue, start=1):
            if interaction.user.id == user_id:
                return position
        return 0

    async def update_positions(self):
        """Refresh every tracked position message; drop those no longer queued.

        Updating is best-effort: a failure for one message is reported and the
        remaining messages are still processed.
        """
        # Iterate over a snapshot because entries are removed inside the loop;
        # mutating the dict while iterating it raises RuntimeError.
        for user_id, msg in list(self.position_messages.items()):
            try:
                new_position = self.get_position(user_id)
                if new_position > 0:
                    await msg.edit(content=f"Queue position: {new_position}")
                    continue
                # Forget the message first so a failed delete is not retried
                # on every subsequent update.
                self.position_messages.pop(user_id, None)
                await msg.delete()
            except Exception as exc:
                print(f"Queue position update failed for user {user_id}: {exc}")

    async def add_generation(self, interaction, params):
        """Enqueue a request and return its 1-based position in the queue.

        When the request is not first in line, a "Queue position" follow-up
        message is sent and remembered so update_positions() can refresh it.
        """
        position = self.queue.qsize() + 1
        await self.queue.put((interaction, params))
        if position > 1:
            msg = await interaction.followup.send(f"Queue position: {position}")
            self.position_messages[interaction.user.id] = msg
        return position

    async def process_queue(self):
        """Consume queued requests forever, running the 'generate' command for each."""
        while True:
            interaction, params = await self.queue.get()
            try:
                # Skip requests whose interaction has already been answered.
                if interaction.response.is_done():
                    continue
                self.active_generations[interaction.id] = params
                await interaction.client.tree.get_command('generate').callback(
                    interaction,
                    **params
                )
                await self.update_positions()
            except Exception as e:
                print(f"Queue processing error: {e}")
            finally:
                # Exactly one task_done() per successful get(), and never leave
                # a finished request registered as active.
                self.active_generations.pop(interaction.id, None)
                self.queue.task_done()
class RegenerateView(discord.ui.View):
    """View with a single button that re-runs the 'generate' command.

    The generation parameters are captured at construction time and replayed
    with a freshly randomized seed when the button is pressed.
    """

    def __init__(self, prompt, negative_prompt, reference_image, guidance_scale, steps, strength, width, height, eta, clip_skip, cross_attention_scale, use_custom_size, num_images, user_id):
        super().__init__(timeout=None)
        # Everything except 'user_id' is forwarded to the generate callback.
        self.params = {
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'reference_image': reference_image,
            'guidance_scale': guidance_scale,
            'steps': steps,
            'strength': strength,
            'width': width,
            'height': height,
            'eta': eta,
            'clip_skip': clip_skip,
            'cross_attention_scale': cross_attention_scale,
            'use_custom_size': use_custom_size,
            'num_images': num_images,
            'user_id': user_id
        }

    # NOTE(review): the label text looks mis-encoded (probably a UTF-8 emoji);
    # confirm the file encoding before changing it.
    @discord.ui.button(label="๐Ÿ”„ Regenerate", style=discord.ButtonStyle.green)
    async def regenerate(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Re-run the 'generate' command with the stored parameters and a new seed."""
        try:
            command = interaction.client.tree.get_command('generate')
            params = {k: v for k, v in self.params.items() if k != 'user_id'}
            params['seed'] = random.randint(1, 1000000)  # Randomize seed for regeneration
            await command.callback(interaction, **params)
        except Exception as e:
            message = f"Regeneration failed: {str(e)}"
            # Pick the channel explicitly instead of trying one and catching
            # whatever the other raises: a follow-up is only valid once the
            # interaction has been responded to.
            if interaction.response.is_done():
                await interaction.followup.send(message, ephemeral=True)
            else:
                await interaction.response.send_message(message, ephemeral=True)
class _GuideAttachment:
    """Attachment-like wrapper around downloaded guide image bytes.

    The 'generate' callback awaits `reference_image.read()` and reads
    `reference_image.url`; `discord.File` provides neither, so guide images
    are handed over through this wrapper instead.
    """

    def __init__(self, content, url):
        self._content = content
        self.url = url

    async def read(self):
        """Return the image bytes."""
        return self._content


class GuideImageView(discord.ui.View):
    """View that lets a user reroll a random guide image and generate from it."""

    def __init__(self):
        super().__init__(timeout=None)
        # Bytes and URL of the most recently fetched guide image.
        self.current_image = None
        self.current_url = None
        self.base_url = "https://naonauno-groundbi-factory.hf.space/guide"
        self.min_id = 1
        self.max_id = 6564  # Adjust this to match your actual maximum ID

    async def get_random_guide_image(self, max_attempts=20):
        """Download a random guide image.

        Random ids are tried until one exists, up to `max_attempts` times, so
        a long run of missing files cannot loop forever. On success the image
        bytes and URL are also recorded on the view.

        Returns:
            (image_bytes, url) on success, or (None, None) when the server
            answers with an unexpected status or every attempt was a 404.
        """
        # One session for all attempts instead of a new one per request.
        async with aiohttp.ClientSession() as session:
            for _ in range(max_attempts):
                random_id = random.randint(self.min_id, self.max_id)
                url = f"{self.base_url}/{random_id}.png"
                async with session.get(url) as response:
                    if response.status == 200:
                        image_data = await response.read()
                        self.current_image = image_data
                        self.current_url = url
                        return image_data, url
                    if response.status != 404:
                        return None, None
                    # 404: try another random ID if the file doesn't exist
        return None, None

    # NOTE(review): the button labels and embed title below look mis-encoded
    # (probably UTF-8 emoji); confirm the file encoding before changing them.
    @discord.ui.button(label="๐ŸŽฒ Reroll", style=discord.ButtonStyle.primary)
    async def reroll(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Fetch another random guide image and post it with this view."""
        await interaction.response.defer()
        image_data, url = await self.get_random_guide_image()
        if not image_data:
            # Tell the user instead of leaving the deferred interaction silent.
            await interaction.followup.send("Failed to fetch guide image", ephemeral=True)
            return
        self.current_image = image_data
        self.current_url = url
        embed = discord.Embed(title="๐ŸŽฒ Random Guide Image", color=discord.Color.blue())
        embed.set_image(url=url)
        await interaction.followup.send(embed=embed, view=self)

    @discord.ui.button(label="๐ŸŒŠ Generate with Water", style=discord.ButtonStyle.success)
    async def generate_water(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Generate terrain with water features from the current guide image."""
        # _generate answers the interaction itself when no image is selected.
        await self._generate(interaction, "terrain with water features, lakes, rivers, and coastlines")

    @discord.ui.button(label="๐Ÿ”๏ธ Generate without Water", style=discord.ButtonStyle.success)
    async def generate_no_water(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Generate dry terrain from the current guide image."""
        await self._generate(interaction, "dry terrain, mountains, valleys, and deserts")

    async def _generate(self, interaction, prompt_prefix):
        """Run the 'generate' command with the current guide image as reference."""
        if not self.current_image:
            await interaction.response.send_message("Please select a guide image first!", ephemeral=True)
            return
        reference_image = _GuideAttachment(self.current_image, self.current_url)
        command = interaction.client.tree.get_command('generate')
        await command.callback(
            interaction,
            prompt=f"{prompt_prefix}, highly detailed terrain, realistic, geological features",
            reference_image=reference_image,
            guidance_scale=1.0,
            steps=30,
            strength=0.8
        )
def setup_discord_bot(generate_image_func):
    """Build the GroundBi Discord bot and return a callable that runs it.

    Args:
        generate_image_func: Callable that performs the generation. It is
            invoked in a worker thread with the keyword arguments image,
            prompt, negative_prompt, seed, guidance_scale, guidance_rescale,
            steps, strength, num_images, use_custom_size, output_width,
            output_height, randomize_seed and progress_callback, and must
            return ``(generated_images, metadata)`` where ``generated_images``
            yields ``(image, _)`` pairs and ``metadata['seeds']`` holds one
            seed per image.

    Returns:
        A zero-argument callable that starts the bot using the token from the
        DISCORD_BOT_TOKEN environment variable.
    """
    ALLOWED_SERVER_ID = os.getenv('DISCORD_SERVER_ID')
    GUIDE_BASE_URL = "https://naonauno-groundbi-factory.hf.space/guide"
    MAX_GUIDE_ID = 6564  # Highest guide image id; files live in guide/<id>.png

    queue_manager = GenerationQueue()
    intents = discord.Intents.default()
    intents.message_content = True
    intents.reactions = True
    bot = commands.Bot(command_prefix='/', intents=intents)
    bot.generate_image_func = generate_image_func
    bot.queue_manager = queue_manager

    # Holds the long-running tasks started by on_ready; also used as the
    # "already initialised" marker.
    background_tasks = []

    class GuideAttachment:
        """Attachment-like wrapper: generate_command awaits read() and reads url."""

        def __init__(self, content, url):
            self._content = content
            self.url = url

        async def read(self):
            return self._content

    async def send_error(interaction, text, ephemeral=False):
        """Send text as a follow-up if the interaction was answered, else as the response."""
        if interaction.response.is_done():
            await interaction.followup.send(text, ephemeral=ephemeral)
        else:
            await interaction.response.send_message(text, ephemeral=ephemeral)

    def pick_local_guide():
        """Return (random_id, relative_path) for a random local guide image."""
        random_id = random.randint(1, MAX_GUIDE_ID)
        return random_id, f"guide/{random_id}.png"

    # NOTE(review): several user-facing strings in this function look
    # mis-encoded (probably UTF-8 emoji); they are kept unchanged here.
    async def cycle_activities():
        """Rotate the bot's presence through a fixed list, one every 50 seconds."""
        activities = [
            ("watching", "for /generate commands ๐ŸŽจ"),
            ("playing", "with GroundBi maps โ›ฐ๏ธ"),
            ("listening", "to your prompts ๐ŸŽง")
        ]
        while True:
            for activity_type, name in activities:
                activity = discord.Activity(
                    type=getattr(discord.ActivityType, activity_type),
                    name=name
                )
                await bot.change_presence(activity=activity)
                await asyncio.sleep(50)

    @bot.tree.command(name="quickgen", description="Generate terrain using a random guide image")
    async def quickgen_command(interaction: discord.Interaction):
        await interaction.response.defer()
        try:
            view = GuideImageView()
            image_data, url = await view.get_random_guide_image()
            if image_data:
                # generate_command needs read()/url, which discord.File lacks.
                reference_image = GuideAttachment(image_data, url)
                await bot.tree.get_command('generate').callback(
                    interaction,
                    prompt="highly detailed terrain, realistic, geological features",
                    reference_image=reference_image,
                    guidance_scale=1.0,
                    steps=30,
                    strength=0.8
                )
            else:
                await interaction.followup.send("Failed to fetch guide image")
        except Exception as e:
            await interaction.followup.send(f"Error: {str(e)}")

    @bot.tree.command(name="guidegen", description="Choose a guide image and generate terrain")
    async def guidegen_command(interaction: discord.Interaction):
        await interaction.response.defer()
        try:
            view = GuideImageView()
            image_data, url = await view.get_random_guide_image()
            if image_data:
                view.current_image = image_data
                embed = discord.Embed(title="๐ŸŽฒ Random Guide Image", color=discord.Color.blue())
                embed.set_image(url=url)
                await interaction.followup.send(embed=embed, view=view)
            else:
                await interaction.followup.send("Failed to fetch guide image")
        except Exception as e:
            await interaction.followup.send(f"Error: {str(e)}")

    class GuideRerollView(discord.ui.View):
        """View whose button swaps the message attachment for another local guide image."""

        def __init__(self):
            super().__init__(timeout=None)

        @discord.ui.button(label="๐ŸŽฒ Reroll", style=discord.ButtonStyle.primary)
        async def reroll(self, interaction: discord.Interaction, button: discord.ui.Button):
            try:
                random_id, path = pick_local_guide()
                file = discord.File(path, filename=f"{random_id}.png")
                await interaction.response.edit_message(attachments=[file], view=self)
            except Exception as e:
                await send_error(interaction, f"Error: {str(e)}", ephemeral=True)

    @bot.tree.command(name="randomguide", description="Get a random guide image")
    async def dev1_command(interaction: discord.Interaction):
        await interaction.response.defer()
        try:
            random_id, path = pick_local_guide()
            # Give discord.File the path so it manages the file handle itself
            # rather than wrapping a handle that a `with` block may close
            # before the upload happens.
            file = discord.File(path, filename=f"{random_id}.png")
            await interaction.followup.send(file=file, view=GuideRerollView())
        except Exception as e:
            await interaction.followup.send(f"Error: {str(e)}")

    @bot.tree.command(name="dev2", description="Generate terrain from a random guide image")
    async def dev2_command(interaction: discord.Interaction):
        random_id, path = pick_local_guide()
        try:
            with open(path, 'rb') as img_file:
                content = img_file.read()
            reference_image = GuideAttachment(content, f"{GUIDE_BASE_URL}/{random_id}.png")
            await bot.tree.get_command('generate').callback(
                interaction,
                prompt="highly detailed terrain, realistic, geological features",
                reference_image=reference_image,
                guidance_scale=1.0,
                steps=30,
                strength=0.8
            )
        except FileNotFoundError:
            await send_error(interaction, f"Image {random_id}.png not found")
        except Exception as e:
            await send_error(interaction, f"Error: {str(e)}")

    @bot.tree.command(name="commands", description="Show available commands")
    async def commands_command(interaction: discord.Interaction):
        help_embed = discord.Embed(
            title="๐Ÿž๏ธ GroundBi Bot Commands",
            description="Generate your map's GroundBi in just one go!",
            color=discord.Color.green()
        )
        help_embed.add_field(
            name="/generate",
            value=(
                "Generate terrain images with customizable parameters\n"
                "**Required Parameters:**\n"
                "- `prompt`: Describe the terrain\n"
                "- `reference_image`: Image to guide generation\n\n"
                "**Optional Parameters:**\n"
                "- `negative_prompt`: What to avoid in the image\n"
                "- `seed`: Control randomness\n"
                "- `guidance_scale`: Prompt adherence (1-20)\n"
                "- `guidance_rescale`: Guidance rescaling factor\n"
                "- `steps`: Generation quality (10-100)\n"
                "- `strength`: Image modification level (0-1)\n"
                "- `width`: Output image width (256-2048)\n"
                "- `height`: Output image height (256-2048)\n"
                "- `num_images`: Number of images (1-4)\n"
                "- `eta`: Additional generation parameter\n"
                "- `clip_skip`: CLIP skip value\n"
                "- `cross_attention_scale`: Cross attention scaling"
            ),
            inline=False
        )
        help_embed.set_footer(text="Brought to you by Major Platonov himself, for the Pliocene... beep-bop... ๐Ÿค–")
        await interaction.response.send_message(embed=help_embed)

    @bot.tree.command(name="generate", description="Generate GroundBi images")
    @app_commands.describe(
        prompt="REQUIRED: Describe the terrain you want to generate",
        reference_image="REQUIRED: Reference image to guide generation",
        negative_prompt="Optional: Describe what you want to avoid in the image",
        seed="Optional: Control randomization (same seed = same output)",
        guidance_scale="Optional: How closely to follow the prompt (1-20)",
        guidance_rescale="Optional: Rescaling factor for guidance (default: 1.0)",
        steps="Optional: Number of diffusion steps (10-100)",
        strength="Optional: Amount of image modification (0-1)",
        use_custom_size="Optional: Whether to use custom dimensions (default: False)",
        width="Optional: Output image width",
        height="Optional: Output image height",
        num_images="Optional: Number of images to generate",
        eta="Optional: Eta parameter for generation",
        clip_skip="Optional: CLIP skip value",
        cross_attention_scale="Optional: Cross attention scale"
    )
    async def generate_command(
        interaction: discord.Interaction,
        prompt: str,
        reference_image: discord.Attachment,
        negative_prompt: str = "",
        seed: int = None,
        guidance_scale: float = 1.0,
        guidance_rescale: float = 1.0,
        steps: int = 25,
        strength: float = 0.8,
        use_custom_size: bool = False,
        width: int = 512,
        height: int = 512,
        num_images: int = 1,
        eta: float = 0.0,
        clip_skip: int = 1,
        cross_attention_scale: float = 1.0,
    ):
        # Other commands call this callback after deferring themselves;
        # deferring twice raises, so only defer when nothing was sent yet.
        if not interaction.response.is_done():
            await interaction.response.defer(ephemeral=False, thinking=True)
        try:
            start_time = time.time()
            image_data = await reference_image.read()
            pil_image = Image.open(io.BytesIO(image_data))
            if seed is None:
                # torch's initial_seed() is not a fresh value per call, so
                # draw a new seed here; same range as before: 0 .. 2**32 - 2.
                seed = random.randrange(2**32 - 1)

            loop = asyncio.get_running_loop()
            milestones = (steps // 4, steps // 2, 3 * steps // 4)

            async def progress_update(step):
                if step in milestones:
                    progress = int((step / steps) * 100)
                    await interaction.followup.send(f"Generation progress: {progress}%")

            def generation_with_progress():
                # Runs in a worker thread: hop back onto the event loop to
                # send progress messages.
                progress_callback = lambda step, _: asyncio.run_coroutine_threadsafe(
                    progress_update(step),
                    loop
                )
                # eta, clip_skip and cross_attention_scale are shown in the
                # result embed but are not forwarded to generate_image_func.
                return bot.generate_image_func(
                    image=pil_image,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    seed=seed,
                    guidance_scale=guidance_scale,
                    guidance_rescale=guidance_rescale,
                    steps=steps,
                    strength=strength,
                    num_images=num_images,
                    use_custom_size=use_custom_size,
                    output_width=width,
                    output_height=height,
                    randomize_seed=False,
                    progress_callback=progress_callback
                )

            generated_images, metadata = await loop.run_in_executor(
                None,
                generation_with_progress
            )
            generation_time = round(time.time() - start_time, 2)

            image_files = []
            for i, (image, _) in enumerate(generated_images):
                img_byte_arr = io.BytesIO()
                image.save(img_byte_arr, format='PNG')
                img_byte_arr.seek(0)
                image_files.append(discord.File(
                    img_byte_arr,
                    filename=f"terrain_{i+1}_seed_{metadata['seeds'][i]}.png"
                ))

            gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU"
            embed = discord.Embed(
                title="๐Ÿž๏ธ GroundBi Results",
                description=f"**Prompt:** {prompt}\n**Negative Prompt:** {negative_prompt}",
                color=discord.Color.green()
            )
            embed.add_field(
                name="Generation Details",
                value=f"**GPU:** {gpu_name}\n"
                      f"**Generation Time:** {generation_time} seconds\n"
                      f"**Model:** GroundBi Terrain Generator",
                inline=False
            )
            embed.add_field(name="Steps", value=str(steps), inline=True)
            embed.add_field(name="Guidance Scale", value=str(guidance_scale), inline=True)
            embed.add_field(name="Strength", value=str(strength), inline=True)
            embed.add_field(name="Dimensions", value=f"{width if width else 'Auto'}x{height if height else 'Auto'}", inline=True)
            seeds_text = "\n".join(
                f"Image {i+1}: {image_seed}" for i, image_seed in enumerate(metadata['seeds'])
            )
            embed.add_field(name="Seeds", value=seeds_text, inline=False)
            embed.add_field(name="Advanced Settings",
                            value=f"Guidance Rescale: {guidance_rescale}\n"
                                  f"Cross Attention Scale: {cross_attention_scale}\n"
                                  f"CLIP Skip: {clip_skip}\n"
                                  f"Eta: {eta}",
                            inline=False)

            view = RegenerateView(
                prompt=prompt,
                negative_prompt=negative_prompt,
                reference_image=reference_image,
                guidance_scale=guidance_scale,
                steps=steps,
                strength=strength,
                width=width,
                height=height,
                eta=eta,
                clip_skip=clip_skip,
                cross_attention_scale=cross_attention_scale,
                use_custom_size=use_custom_size,
                num_images=num_images,
                user_id=interaction.user.id
            )
            message = await interaction.followup.send(
                files=image_files,
                embed=embed,
                view=view
            )
            embed.set_thumbnail(url=reference_image.url)
            embed.description = f"**Prompt:** {prompt}\n**Negative Prompt:** {negative_prompt}\n\n[Click here to view full-size reference image]({reference_image.url})"
            await message.edit(embed=embed)
        except torch.cuda.OutOfMemoryError:
            await interaction.followup.send("GPU memory exceeded. Try reducing image size or steps.")
        except ValueError as ve:
            await interaction.followup.send(f"Invalid parameters: {str(ve)}")
        except Exception as e:
            await interaction.followup.send(f"Error generating terrain: {str(e)}")
        finally:
            torch.cuda.empty_cache()
            gc.collect()

    async def setup():
        """Copy the global commands to the allowed guild and sync them there."""
        if not ALLOWED_SERVER_ID:
            # int(None) would raise inside on_ready; report the cause instead.
            print("DISCORD_SERVER_ID is not set; skipping guild command sync")
            return
        guild = discord.Object(id=int(ALLOWED_SERVER_ID))
        bot.tree.copy_global_to(guild=guild)
        await bot.tree.sync(guild=guild)

    @bot.event
    async def on_ready():
        # on_ready may be dispatched more than once (e.g. after a reconnect);
        # sync commands and start the background loops only the first time.
        if not background_tasks:
            await setup()
            background_tasks.append(bot.loop.create_task(cycle_activities()))
            background_tasks.append(bot.loop.create_task(queue_manager.process_queue()))
        await bot.change_presence(
            activity=discord.Activity(
                type=discord.ActivityType.watching,
                name="for /generate commands ๐ŸŽจ"
            )
        )
        print(f"Bot is ready in guild {ALLOWED_SERVER_ID}")

    return lambda: bot.run(os.getenv('DISCORD_BOT_TOKEN'))