Spaces:
Paused
Paused
import os | |
import discord | |
from discord import app_commands | |
from discord.ext import commands | |
import torch | |
from PIL import Image | |
import io | |
import asyncio | |
import aiohttp | |
import gc | |
from concurrent.futures import ThreadPoolExecutor | |
import time | |
import random | |
class GenerationQueue:
    """Run queued 'generate' command requests one at a time, in FIFO order.

    Items are ``(interaction, params)`` tuples. ``process_queue`` looks up the
    app command named 'generate' on ``interaction.client.tree`` and awaits its
    callback with ``params`` as keyword arguments.
    """

    def __init__(self):
        self.queue = asyncio.Queue()
        # interaction.id -> params of the request currently being processed.
        self.active_generations = {}
        # user id -> message showing that user's queue position.
        self.position_messages = {}

    def get_position(self, user_id):
        """Return the 1-based position of ``user_id``'s first pending request, or 0 if none."""
        # asyncio.Queue has no public API for inspecting pending items, so this reads
        # the private ``_queue`` deque (snapshotted before iterating).
        pending = list(self.queue._queue)
        for position, (interaction, _) in enumerate(pending, start=1):
            if interaction.user.id == user_id:
                return position
        return 0

    async def update_positions(self):
        """Refresh every position message; delete those whose user has nothing pending.

        Failures on individual messages are printed and skipped so one stale
        message cannot stop the others from being updated.
        """
        # Iterate over a snapshot: entries are removed inside the loop, and mutating
        # a dict while iterating it directly raises RuntimeError.
        for user_id, msg in list(self.position_messages.items()):
            try:
                new_position = self.get_position(user_id)
                if new_position > 0:
                    await msg.edit(content=f"Queue position: {new_position}")
                else:
                    # Forget the message before deleting it so a failed delete is not
                    # retried on every update. Only drop it if it is still this user's
                    # tracked message (add_generation may have replaced it during an await).
                    if self.position_messages.get(user_id) is msg:
                        del self.position_messages[user_id]
                    await msg.delete()
            except Exception as e:
                print(f"Queue position update error: {e}")

    async def add_generation(self, interaction, params):
        """Enqueue a request and return its 1-based position.

        When other requests are already pending, a "Queue position" message is sent
        through ``interaction.followup`` and tracked so it can be updated later.
        """
        position = self.queue.qsize() + 1
        await self.queue.put((interaction, params))
        if position > 1:
            msg = await interaction.followup.send(f"Queue position: {position}")
            self.position_messages[interaction.user.id] = msg
        return position

    async def process_queue(self):
        """Consume the queue forever, running one request at a time.

        Requests whose interaction already has a response are skipped. Exceptions
        raised by a request are printed and do not stop the loop.
        """
        while True:
            interaction, params = await self.queue.get()
            try:
                # This request just left the queue, so everyone behind it moved up.
                await self.update_positions()
                if not interaction.response.is_done():
                    self.active_generations[interaction.id] = params
                    command = interaction.client.tree.get_command('generate')
                    await command.callback(interaction, **params)
            except Exception as e:
                print(f"Queue processing error: {e}")
            finally:
                # Exactly one task_done() per get(), and no leftover entry once done.
                self.active_generations.pop(interaction.id, None)
                self.queue.task_done()
class RegenerateView(discord.ui.View):
    """Keep the settings of a finished generation so it can be re-run with a new seed.

    ``user_id`` is stored alongside the other settings but is not checked
    before regenerating.
    """

    def __init__(self, prompt, negative_prompt, reference_image, guidance_scale, steps, strength, width, height, eta, clip_skip, cross_attention_scale, use_custom_size, num_images, user_id):
        super().__init__(timeout=None)
        self.params = {
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'reference_image': reference_image,
            'guidance_scale': guidance_scale,
            'steps': steps,
            'strength': strength,
            'width': width,
            'height': height,
            'eta': eta,
            'clip_skip': clip_skip,
            'cross_attention_scale': cross_attention_scale,
            'use_custom_size': use_custom_size,
            'num_images': num_images,
            'user_id': user_id
        }

    # NOTE(review): this has the (interaction, button) signature of a
    # discord.ui.button callback, but no @discord.ui.button decorator is visible
    # here. Confirm the button is actually registered.
    async def regenerate(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Re-run the 'generate' app command with the stored settings and a random seed.

        Any exception is reported back to the user as an ephemeral message.
        """
        try:
            command = interaction.client.tree.get_command('generate')
            # user_id is bookkeeping, not a 'generate' parameter.
            params = {key: value for key, value in self.params.items() if key != 'user_id'}
            params['seed'] = random.randint(1, 1000000)  # Randomize seed for regeneration
            await command.callback(interaction, **params)
        except Exception as e:
            message = f"Regeneration failed: {str(e)}"
            # A followup needs an existing response; otherwise answer directly.
            if interaction.response.is_done():
                await interaction.followup.send(message, ephemeral=True)
            else:
                await interaction.response.send_message(message, ephemeral=True)
class _GuideImageAttachment:
    """Minimal in-memory stand-in for ``discord.Attachment``.

    Provides the two members that the 'generate' command callback in this file
    uses on its ``reference_image``: an awaitable ``read()`` and a ``url``
    attribute.
    """

    def __init__(self, data, url=None):
        self._data = data
        self.url = url

    async def read(self):
        return self._data


class GuideImageView(discord.ui.View):
    """View for browsing random guide images and generating terrain from the selected one.

    ``current_image`` holds the raw bytes of the selected guide image. It must be
    set (by ``reroll`` or by whoever creates the view) before the generate
    actions do anything. ``current_url`` is the URL of the most recent
    successful ``get_random_guide_image`` fetch.
    """

    def __init__(self):
        super().__init__(timeout=None)
        self.current_image = None
        self.current_url = None
        self.base_url = "https://naonauno-groundbi-factory.hf.space/guide"
        self.min_id = 1
        self.max_id = 6564  # Adjust this to match your actual maximum ID

    async def get_random_guide_image(self, max_attempts=10):
        """Download a random guide image from ``base_url``.

        Tries up to ``max_attempts`` random ids in ``[min_id, max_id]``, moving on
        to another id whenever the server answers 404.

        Returns:
            ``(bytes, url)`` on HTTP 200. ``(None, None)`` on any other status, or
            when every attempt returned 404.
        """
        # One session for all attempts rather than a new session per try.
        async with aiohttp.ClientSession() as session:
            for _ in range(max_attempts):
                random_id = random.randint(self.min_id, self.max_id)
                url = f"{self.base_url}/{random_id}.png"
                async with session.get(url) as response:
                    if response.status == 200:
                        data = await response.read()
                        self.current_url = url
                        return data, url
                    if response.status != 404:
                        return None, None
        return None, None

    # NOTE(review): reroll / generate_water / generate_no_water have the
    # (interaction, button) signature of discord.ui.button callbacks, but no
    # @discord.ui.button decorators are visible here. Confirm they are registered.
    async def reroll(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Fetch a new random guide image, select it, and post it with this view."""
        await interaction.response.defer()
        image_data, url = await self.get_random_guide_image()
        if not image_data:
            # Report the failure instead of ignoring the click silently.
            await interaction.followup.send("Failed to fetch guide image", ephemeral=True)
            return
        self.current_image = image_data
        embed = discord.Embed(title="๐ฒ Random Guide Image", color=discord.Color.blue())
        embed.set_image(url=url)
        await interaction.followup.send(embed=embed, view=self)

    async def generate_water(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Generate terrain with water features from the selected guide image."""
        # _generate itself tells the user when no guide image is selected.
        await self._generate(interaction, "terrain with water features, lakes, rivers, and coastlines")

    async def generate_no_water(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Generate dry terrain from the selected guide image."""
        await self._generate(interaction, "dry terrain, mountains, valleys, and deserts")

    async def _generate(self, interaction, prompt_prefix):
        """Run the 'generate' command on the selected guide image.

        The prompt is ``prompt_prefix`` followed by fixed terrain keywords.
        """
        if not self.current_image:
            await interaction.response.send_message("Please select a guide image first!", ephemeral=True)
            return
        # The generate callback awaits reference_image.read() and reads .url.
        # discord.File offers neither, so wrap the raw bytes instead.
        reference_image = _GuideImageAttachment(self.current_image, self.current_url)
        command = interaction.client.tree.get_command('generate')
        await command.callback(
            interaction,
            prompt=f"{prompt_prefix}, highly detailed terrain, realistic, geological features",
            reference_image=reference_image,
            guidance_scale=1.0,
            steps=30,
            strength=0.8,
        )
def setup_discord_bot(generate_image_func):
    """Build the GroundBi Discord bot and return a zero-argument function that runs it.

    ``generate_image_func`` is called from a worker thread with the keyword
    arguments used in ``generate_command``. It must return a
    ``(generated_images, metadata)`` pair:
    - each element of ``generated_images`` unpacks as ``(image, _)``, where
      ``image`` supports ``save(fp, format='PNG')``;
    - ``metadata['seeds']`` lists one seed per image.

    The guild to sync commands to is read from ``DISCORD_SERVER_ID`` when the bot
    becomes ready. The bot token is read from ``DISCORD_BOT_TOKEN`` only when the
    returned function is called.
    """
    # NOTE(review): no @bot.event / @bot.tree.command / @discord.ui.button
    # decorators are visible in this function, and app_commands is imported but
    # unused in the file as shown. Yet the code below looks commands up with
    # bot.tree.get_command('generate'). Confirm the handlers are registered.
    ALLOWED_SERVER_ID = os.getenv('DISCORD_SERVER_ID')
    # Highest id of the local guide/<id>.png files used by the dev commands (ids start at 1).
    GUIDE_MAX_ID = 6564
    queue_manager = GenerationQueue()
    intents = discord.Intents.default()
    intents.message_content = True
    intents.reactions = True
    bot = commands.Bot(command_prefix='/', intents=intents)
    bot.generate_image_func = generate_image_func
    bot.queue_manager = queue_manager
    # Background tasks started by on_ready. Keeping references stops them from being
    # garbage-collected, and a non-empty list marks one-time startup as done.
    background_tasks = []

    class _InMemoryAttachment:
        """Stand-in for discord.Attachment: an awaitable read() plus a url attribute."""

        def __init__(self, data, url=None):
            self._data = data
            self.url = url

        async def read(self):
            return self._data

    def _random_guide_path():
        """Pick a random local guide image and return ``(guide_id, path)``."""
        guide_id = random.randint(1, GUIDE_MAX_ID)
        return guide_id, f"guide/{guide_id}.png"

    async def _send_error(interaction, message):
        """Send ``message`` whether or not the interaction has already been answered."""
        if interaction.response.is_done():
            await interaction.followup.send(message)
        else:
            await interaction.response.send_message(message)

    async def cycle_activities():
        """Rotate the bot's presence through a fixed list, 50 seconds each, forever."""
        activities = [
            ("watching", "for /generate commands ๐จ"),
            ("playing", "with GroundBi maps โฐ๏ธ"),
            ("listening", "to your prompts ๐ง")
        ]
        while True:
            for activity_type, name in activities:
                activity = discord.Activity(
                    type=getattr(discord.ActivityType, activity_type),
                    name=name
                )
                await bot.change_presence(activity=activity)
                await asyncio.sleep(50)

    # Command callbacks below are documented with comments rather than docstrings.
    # discord.py uses a command callback's docstring as its slash-command
    # description when none is given explicitly.

    async def quickgen_command(interaction: discord.Interaction):
        # Generate terrain from a random remote guide image with default settings.
        # Fetching the guide image can be slow, so defer first. The generate
        # callback only defers when no response exists yet.
        await interaction.response.defer()
        try:
            view = GuideImageView()
            image_data, url = await view.get_random_guide_image()
            if image_data:
                # generate awaits reference_image.read() and reads .url; discord.File
                # has neither, so wrap the raw bytes.
                reference_image = _InMemoryAttachment(image_data, url)
                await bot.tree.get_command('generate').callback(
                    interaction,
                    prompt="highly detailed terrain, realistic, geological features",
                    reference_image=reference_image,
                    guidance_scale=1.0,
                    steps=30,
                    strength=0.8
                )
            else:
                await interaction.followup.send("Failed to fetch guide image")
        except Exception as e:
            await interaction.followup.send(f"Error: {str(e)}")

    async def guidegen_command(interaction: discord.Interaction):
        # Post a random remote guide image together with a GuideImageView.
        await interaction.response.defer()
        try:
            view = GuideImageView()
            image_data, url = await view.get_random_guide_image()
            if image_data:
                view.current_image = image_data
                embed = discord.Embed(title="๐ฒ Random Guide Image", color=discord.Color.blue())
                embed.set_image(url=url)
                await interaction.followup.send(embed=embed, view=view)
            else:
                await interaction.followup.send("Failed to fetch guide image")
        except Exception as e:
            await interaction.followup.send(f"Error: {str(e)}")

    class GuideRerollView(discord.ui.View):
        """View for /dev1 output that swaps the attachment for another local guide image."""

        def __init__(self):
            super().__init__(timeout=None)

        # NOTE(review): no @discord.ui.button decorator is visible; confirm registration.
        async def reroll(self, interaction: discord.Interaction, button: discord.ui.Button):
            """Replace the message's attachment with another random local guide image."""
            try:
                guide_id, path = _random_guide_path()
                file = discord.File(path, filename=f"{guide_id}.png")
                await interaction.response.edit_message(attachments=[file], view=self)
            except Exception as e:
                await interaction.response.send_message(f"Error: {str(e)}", ephemeral=True)

    async def dev1_command(interaction: discord.Interaction):
        # Post a random local guide image with a GuideRerollView.
        await interaction.response.defer()
        try:
            guide_id, path = _random_guide_path()
            # Pass the path so discord.File opens the file itself. A handle from a
            # `with` block would already be closed when the upload reads it.
            file = discord.File(path, filename=f"{guide_id}.png")
            view = GuideRerollView()
            await interaction.followup.send(file=file, view=view)
        except Exception as e:
            await interaction.followup.send(f"Error: {str(e)}")

    async def dev2_command(interaction: discord.Interaction):
        # Generate terrain from a random local guide image with default settings.
        guide_id, path = _random_guide_path()
        try:
            # Read the file content; the handle is not needed during generation.
            with open(path, 'rb') as img_file:
                image_bytes = img_file.read()
            # url must be a plain attribute: generate uses it as a string URL.
            reference_image = _InMemoryAttachment(
                image_bytes,
                f"https://naonauno-groundbi-factory.hf.space/guide/{os.path.basename(path)}",
            )
            await bot.tree.get_command('generate').callback(
                interaction,
                prompt="highly detailed terrain, realistic, geological features",
                reference_image=reference_image,
                guidance_scale=1.0,
                steps=30,
                strength=0.8
            )
        except FileNotFoundError:
            await _send_error(interaction, f"Image {guide_id}.png not found")
        except Exception as e:
            await _send_error(interaction, f"Error: {str(e)}")

    async def commands_command(interaction: discord.Interaction):
        # Reply with an embed describing the /generate command and its parameters.
        help_embed = discord.Embed(
            title="๐๏ธ GroundBi Bot Commands",
            description="Generate your map's GroundBi in just one go!",
            color=discord.Color.green()
        )
        help_embed.add_field(
            name="/generate",
            value=(
                "Generate terrain images with customizable parameters\n"
                "**Required Parameters:**\n"
                "- `prompt`: Describe the terrain\n"
                "- `reference_image`: Image to guide generation\n\n"
                "**Optional Parameters:**\n"
                "- `negative_prompt`: What to avoid in the image\n"
                "- `seed`: Control randomness\n"
                "- `guidance_scale`: Prompt adherence (1-20)\n"
                "- `guidance_rescale`: Guidance rescaling factor\n"
                "- `steps`: Generation quality (10-100)\n"
                "- `strength`: Image modification level (0-1)\n"
                "- `width`: Output image width (256-2048)\n"
                "- `height`: Output image height (256-2048)\n"
                "- `num_images`: Number of images (1-4)\n"
                "- `eta`: Additional generation parameter\n"
                "- `clip_skip`: CLIP skip value\n"
                "- `cross_attention_scale`: Cross attention scaling"
            ),
            inline=False
        )
        help_embed.set_footer(text="Brought to you by Major Platonov himself, for the Pliocene... beep-bop... ๐ค")
        await interaction.response.send_message(embed=help_embed)

    async def generate_command(
        interaction: discord.Interaction,
        prompt: str,
        reference_image: discord.Attachment,
        negative_prompt: str = "",
        seed: int = None,
        guidance_scale: float = 1.0,
        guidance_rescale: float = 1.0,
        steps: int = 25,
        strength: float = 0.8,
        use_custom_size: bool = False,
        width: int = 512,
        height: int = 512,
        num_images: int = 1,
        eta: float = 0.0,
        clip_skip: int = 1,
        cross_attention_scale: float = 1.0,
    ):
        # Run generate_image_func on the reference image in a worker thread and post
        # the results with a RegenerateView. reference_image only needs an awaitable
        # read() and (optionally) a string `url` attribute.
        # NOTE(review): eta, clip_skip and cross_attention_scale are shown in the
        # result embed but never passed to generate_image_func. Confirm intent.
        # Some callers defer before invoking this. Deferring twice raises, so only
        # defer when no response has been sent yet.
        if not interaction.response.is_done():
            await interaction.response.defer(ephemeral=False, thinking=True)
        try:
            start_time = time.time()
            image_data = await reference_image.read()
            pil_image = Image.open(io.BytesIO(image_data))
            if seed is None:
                # torch.random.initial_seed() returns the default generator's stored
                # seed, not a fresh random value, so successive calls would reuse it.
                seed = random.randrange(2**32 - 1)
            loop = asyncio.get_running_loop()
            # Post progress at roughly 25/50/75%. Step 0 is excluded so tiny step
            # counts do not post "0%".
            progress_steps = {steps // 4, steps // 2, 3 * steps // 4} - {0}

            async def progress_update(step):
                if step in progress_steps:
                    progress = int((step / steps) * 100)
                    await interaction.followup.send(f"Generation progress: {progress}%")

            def progress_callback(step, _):
                # generate_image_func runs in an executor thread, so hand the
                # coroutine to the event loop thread-safely.
                return asyncio.run_coroutine_threadsafe(progress_update(step), loop)

            def generation_with_progress():
                return bot.generate_image_func(
                    image=pil_image,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    seed=seed,
                    guidance_scale=guidance_scale,
                    guidance_rescale=guidance_rescale,
                    steps=steps,
                    strength=strength,
                    num_images=num_images,
                    use_custom_size=use_custom_size,
                    output_width=width,
                    output_height=height,
                    randomize_seed=False,
                    progress_callback=progress_callback
                )

            generated_images, metadata = await loop.run_in_executor(None, generation_with_progress)
            generation_time = round(time.time() - start_time, 2)

            image_files = []
            for i, (image, _) in enumerate(generated_images):
                img_byte_arr = io.BytesIO()
                image.save(img_byte_arr, format='PNG')
                img_byte_arr.seek(0)
                image_files.append(discord.File(
                    img_byte_arr,
                    filename=f"terrain_{i+1}_seed_{metadata['seeds'][i]}.png"
                ))

            gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU"
            embed = discord.Embed(
                title="๐๏ธ GroundBi Results",
                description=f"**Prompt:** {prompt}\n**Negative Prompt:** {negative_prompt}",
                color=discord.Color.green()
            )
            embed.add_field(
                name="Generation Details",
                value=f"**GPU:** {gpu_name}\n"
                      f"**Generation Time:** {generation_time} seconds\n"
                      f"**Model:** GroundBi Terrain Generator",
                inline=False
            )
            embed.add_field(name="Steps", value=str(steps), inline=True)
            embed.add_field(name="Guidance Scale", value=str(guidance_scale), inline=True)
            embed.add_field(name="Strength", value=str(strength), inline=True)
            embed.add_field(name="Dimensions", value=f"{width if width else 'Auto'}x{height if height else 'Auto'}", inline=True)
            seeds_text = "\n".join(
                f"Image {i+1}: {image_seed}" for i, image_seed in enumerate(metadata['seeds'])
            )
            embed.add_field(name="Seeds", value=seeds_text, inline=False)
            embed.add_field(name="Advanced Settings",
                            value=f"Guidance Rescale: {guidance_rescale}\n"
                                  f"Cross Attention Scale: {cross_attention_scale}\n"
                                  f"CLIP Skip: {clip_skip}\n"
                                  f"Eta: {eta}",
                            inline=False)

            view = RegenerateView(
                prompt=prompt,
                negative_prompt=negative_prompt,
                reference_image=reference_image,
                guidance_scale=guidance_scale,
                steps=steps,
                strength=strength,
                width=width,
                height=height,
                eta=eta,
                clip_skip=clip_skip,
                cross_attention_scale=cross_attention_scale,
                use_custom_size=use_custom_size,
                num_images=num_images,
                user_id=interaction.user.id
            )
            message = await interaction.followup.send(
                files=image_files,
                embed=embed,
                view=view
            )

            # Add the reference thumbnail and link only when a real URL string exists.
            reference_url = getattr(reference_image, "url", None)
            if isinstance(reference_url, str) and reference_url:
                embed.set_thumbnail(url=reference_url)
                embed.description = (
                    f"**Prompt:** {prompt}\n**Negative Prompt:** {negative_prompt}\n\n"
                    f"[Click here to view full-size reference image]({reference_url})"
                )
                try:
                    await message.edit(embed=embed)
                except discord.HTTPException as e:
                    # The results are already posted; don't report the generation as failed.
                    print(f"Could not add reference thumbnail: {e}")
        except torch.cuda.OutOfMemoryError:
            await interaction.followup.send("GPU memory exceeded. Try reducing image size or steps.")
        except ValueError as ve:
            await interaction.followup.send(f"Invalid parameters: {str(ve)}")
        except Exception as e:
            await interaction.followup.send(f"Error generating terrain: {str(e)}")
        finally:
            torch.cuda.empty_cache()
            gc.collect()

    async def setup():
        # Copy the globally defined app commands to the allowed guild and sync them there.
        guild = discord.Object(id=int(ALLOWED_SERVER_ID))
        bot.tree.copy_global_to(guild=guild)
        await bot.tree.sync(guild=guild)

    async def on_ready():
        # on_ready may fire again after a reconnect. Sync commands and start the
        # background loops only once, so they are not duplicated.
        if not background_tasks:
            await setup()
            background_tasks.append(asyncio.create_task(cycle_activities()))
            background_tasks.append(asyncio.create_task(queue_manager.process_queue()))
        await bot.change_presence(
            activity=discord.Activity(
                type=discord.ActivityType.watching,
                name="for /generate commands ๐จ"
            )
        )
        print(f"Bot is ready in guild {ALLOWED_SERVER_ID}")

    return lambda: bot.run(os.getenv('DISCORD_BOT_TOKEN'))