Spaces:
Running
Running
import os | |
import re | |
from dotenv import load_dotenv, find_dotenv | |
import json | |
import gradio as gr | |
import torch # first import torch then transformers | |
from torch.nn.functional import softmax | |
from transformers import AutoModelForSequenceClassification | |
from huggingface_hub import InferenceClient | |
from transformers import pipeline | |
from huggingface_hub import login | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import logging | |
import sys | |
from datetime import datetime | |
import psutil | |
from typing import Dict, Any, Optional, Tuple | |
# # Add model caching and optimization | |
# from functools import lru_cache | |
# import torch.nn as nn | |
# Custom tprint function with timestamp | |
def tprint(*args, **kwargs): | |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | |
print(f"[{timestamp}] [{sys._getframe().f_back.f_lineno}]", *args, **kwargs) | |
# Configure logging with timestamp and line numbers | |
logging.basicConfig( | |
level=logging.INFO, | |
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s", | |
datefmt="%Y-%m-%d %H:%M:%S", | |
) | |
logger = logging.getLogger(__name__) | |
def get_available_memory(): | |
"""Get available GPU and system memory""" | |
gpu_memory = None | |
if torch.cuda.is_available(): | |
gpu_memory = torch.cuda.get_device_properties(0).total_memory | |
system_memory = psutil.virtual_memory().available | |
return gpu_memory, system_memory | |
def load_env(): | |
_ = load_dotenv(find_dotenv()) | |
def get_huggingface_api_key(): | |
load_env() | |
huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY") | |
if not huggingface_api_key: | |
logging.error("HUGGINGFACE_API_KEY not found in environment variables") | |
raise ValueError("HUGGINGFACE_API_KEY not found in environment variables") | |
return huggingface_api_key | |
def get_huggingface_inference_key(): | |
load_env() | |
huggingface_inference_key = os.getenv("HUGGINGFACE_INFERENCE_KEY") | |
if not huggingface_inference_key: | |
logging.error("HUGGINGFACE_API_KEY not found in environment variables") | |
raise ValueError("HUGGINGFACE_API_KEY not found in environment variables") | |
return huggingface_inference_key | |
# Model configuration | |
MODEL_CONFIG = { | |
"main_model": { | |
# "name": "meta-llama/Llama-3.2-3B-Instruct", | |
# "name": "meta-llama/Llama-3.2-1B-Instruct", # to fit in cpu on hugging face space | |
"name": "meta-llama/Llama-3.2-1B", # to fit in cpu on hugging face space | |
# "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # to fit in cpu on hugging face space | |
# "name": "microsoft/phi-2", | |
# "dtype": torch.bfloat16, | |
"dtype": torch.float32, # Use float32 for CPU | |
"max_length": 512, | |
"device": "cuda" if torch.cuda.is_available() else "cpu", | |
}, | |
"safety_model": { | |
"name": "meta-llama/Llama-Guard-3-1B", | |
# "dtype": torch.bfloat16, | |
"dtype": torch.float32, # Use float32 for CPU | |
"max_length": 256, | |
"device": "cuda" if torch.cuda.is_available() else "cpu", | |
"max_tokens": 500, | |
}, | |
} | |
PROMPT_GUARD_CONFIG = { | |
"model_id": "meta-llama/Prompt-Guard-86M", | |
"temperature": 1.0, | |
"jailbreak_threshold": 0.5, | |
"injection_threshold": 0.9, | |
"device": "cpu", | |
"safe_commands": [ | |
"look around", | |
"investigate", | |
"explore", | |
"search", | |
"examine", | |
"take", | |
"use", | |
"go", | |
"walk", | |
"continue", | |
"help", | |
"inventory", | |
"quest", | |
"status", | |
"map", | |
"talk", | |
"fight", | |
"run", | |
"hide", | |
], | |
"max_length": 512, | |
} | |
def initialize_prompt_guard(): | |
"""Initialize Prompt Guard model""" | |
try: | |
api_key = get_huggingface_api_key() | |
login(token=api_key) | |
tokenizer = AutoTokenizer.from_pretrained(PROMPT_GUARD_CONFIG["model_id"]) | |
model = AutoModelForSequenceClassification.from_pretrained( | |
PROMPT_GUARD_CONFIG["model_id"] | |
) | |
return model, tokenizer | |
except Exception as e: | |
logger.error(f"Failed to initialize Prompt Guard: {e}") | |
raise | |
def get_class_probabilities(text: str, guard_model, guard_tokenizer) -> torch.Tensor: | |
"""Evaluate model probabilities with temperature scaling""" | |
try: | |
inputs = guard_tokenizer( | |
text, | |
return_tensors="pt", | |
padding=True, | |
truncation=True, | |
max_length=PROMPT_GUARD_CONFIG["max_length"], | |
).to(PROMPT_GUARD_CONFIG["device"]) | |
with torch.no_grad(): | |
logits = guard_model(**inputs).logits | |
scaled_logits = logits / PROMPT_GUARD_CONFIG["temperature"] | |
return softmax(scaled_logits, dim=-1) | |
except Exception as e: | |
logger.error(f"Error getting class probabilities: {e}") | |
return None | |
def get_jailbreak_score(text: str, guard_model, guard_tokenizer) -> float: | |
"""Get jailbreak probability score""" | |
try: | |
probabilities = get_class_probabilities(text, guard_model, guard_tokenizer) | |
if probabilities is None: | |
return 1.0 # Fail safe | |
return probabilities[0, 2].item() | |
except Exception as e: | |
logger.error(f"Error getting jailbreak score: {e}") | |
return 1.0 | |
def get_injection_score(text: str, guard_model, guard_tokenizer) -> float: | |
"""Get injection probability score""" | |
try: | |
probabilities = get_class_probabilities(text, guard_model, guard_tokenizer) | |
if probabilities is None: | |
return 1.0 # Fail safe | |
return (probabilities[0, 1] + probabilities[0, 2]).item() | |
except Exception as e: | |
logger.error(f"Error getting injection score: {e}") | |
return 1.0 | |
# Initialize safety model pipeline | |
try: | |
# Initialize Prompt Guard | |
guard_model, guard_tokenizer = initialize_prompt_guard() | |
except Exception as e: | |
logger.error(f"Failed to initialize model: {str(e)}") | |
def is_prompt_safe(message: str) -> bool: | |
"""Enhanced safety check with Prompt Guard""" | |
try: | |
# Allow safe game commands | |
if any(cmd in message.lower() for cmd in PROMPT_GUARD_CONFIG["safe_commands"]): | |
logger.info("Message matched safe command pattern") | |
return True | |
# Get safety scores | |
jailbreak_score = get_jailbreak_score(message, guard_model, guard_tokenizer) | |
injection_score = get_injection_score(message, guard_model, guard_tokenizer) | |
logger.info( | |
f"Safety scores - Jailbreak: {jailbreak_score}, Injection: {injection_score}" | |
) | |
# Check against thresholds | |
is_safe = ( | |
jailbreak_score | |
< PROMPT_GUARD_CONFIG["jailbreak_threshold"] | |
# and injection_score < PROMPT_GUARD_CONFIG["injection_threshold"] # Disable for now because injection is too strict and current prompt guard model seems malfunctioning for now. | |
) | |
logger.info(f"Final safety result: {is_safe}") | |
return is_safe | |
except Exception as e: | |
logger.error(f"Safety check failed: {e}") | |
return False | |
# def initialize_model_pipeline(model_name, force_cpu=False): | |
# """Initialize pipeline with memory management""" | |
# try: | |
# if force_cpu: | |
# device = -1 | |
# else: | |
# device = MODEL_CONFIG["main_model"]["device"] | |
# api_key = get_huggingface_api_key() | |
# # Use 8-bit quantization for memory efficiency | |
# model = AutoModelForCausalLM.from_pretrained( | |
# model_name, | |
# load_in_8bit=False, | |
# torch_dtype=MODEL_CONFIG["main_model"]["dtype"], | |
# use_cache=True, | |
# device_map="auto", | |
# low_cpu_mem_usage=True, | |
# trust_remote_code=True, | |
# token=api_key, # Add token here | |
# ) | |
# model.config.use_cache = True | |
# tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key) | |
# # Initialize pipeline | |
# logger.info(f"Initializing pipeline with device: {device}") | |
# generator = pipeline( | |
# "text-generation", | |
# model=model, | |
# tokenizer=tokenizer, | |
# # device=device, | |
# # temperature=0.7, | |
# model_kwargs={"low_cpu_mem_usage": True}, | |
# ) | |
# logger.info("Model Pipeline initialized successfully") | |
# return generator, tokenizer | |
# except ImportError as e: | |
# logger.error(f"Missing required package: {str(e)}") | |
# raise | |
# except Exception as e: | |
# logger.error(f"Failed to initialize pipeline: {str(e)}") | |
# raise | |
# # Initialize model pipeline | |
# try: | |
# # Use a smaller model for testing | |
# # model_name = "meta-llama/Meta-Llama-3-8B-Instruct" | |
# # model_name = "google/gemma-2-2b" # Start with a smaller model | |
# # model_name = "microsoft/phi-2" | |
# # model_name = "meta-llama/Llama-3.2-1B-Instruct" | |
# # model_name = "meta-llama/Llama-3.2-3B-Instruct" | |
# model_name = MODEL_CONFIG["main_model"]["name"] | |
# # Initialize the pipeline with memory management | |
# generator, tokenizer = initialize_model_pipeline(model_name) | |
# except Exception as e: | |
# logger.error(f"Failed to initialize model: {str(e)}") | |
# # Fallback to CPU if GPU initialization fails | |
# try: | |
# logger.info("Attempting CPU fallback...") | |
# generator, tokenizer = initialize_model_pipeline(model_name, force_cpu=True) | |
# except Exception as e: | |
# logger.error(f"CPU fallback failed: {str(e)}") | |
# raise | |
def initialize_inference_client(): | |
"""Initialize HuggingFace Inference Client""" | |
try: | |
inference_key = get_huggingface_inference_key() | |
client = InferenceClient(api_key=inference_key) | |
logger.info("Inference Client initialized successfully") | |
return client | |
except Exception as e: | |
logger.error(f"Failed to initialize Inference Client: {e}") | |
raise | |
# Initialize inference client and make API call | |
try: | |
inference_client = initialize_inference_client() | |
except Exception as e: | |
logger.error(f"Failed to initialize the inference client model: {str(e)}") | |
def load_world(filename): | |
with open(filename, "r") as f: | |
return json.load(f) | |
# Define system_prompt and model | |
system_prompt = """You are an AI Game master. Your job is to write what happens next in a player's adventure game. | |
CRITICAL Rules: | |
- Write EXACTLY 3 sentences maximum | |
- Use daily English language | |
- Start with "You " | |
- Don't use 'Elara' or 'she/he', only use 'you' | |
- Use only second person ("you") | |
- Never include dialogue after the response | |
- Never continue with additional actions or responses | |
- Never add follow-up questions or choices | |
- Never include 'User:' or 'Assistant:' in response | |
- Never include any note or these kinds of sentences: 'Note from the game master' | |
- Never use ellipsis (...) | |
- Never include 'What would you like to do?' or similar prompts | |
- Always finish with one real response | |
- Never use 'Your turn' or or anything like conversation starting prompts | |
- Always end the response with a period(.)""" | |
def get_game_state(inventory: Dict = None) -> Dict[str, Any]: | |
"""Initialize game state with safe defaults and quest system""" | |
try: | |
# Load world data | |
world = load_world("shared_data/Ethoria.json") | |
character = world["kingdoms"]["Valdor"]["towns"]["Ravenhurst"]["npcs"][ | |
"Elara Brightshield" | |
] | |
tprint(f"character in get_game_state: {character}") | |
game_state = { | |
"name": world["name"], | |
"world": world["description"], | |
"kingdom": world["kingdoms"]["Valdor"]["description"], | |
"town_name": world["kingdoms"]["Valdor"]["towns"]["Ravenhurst"]["name"], | |
"town": world["kingdoms"]["Valdor"]["towns"]["Ravenhurst"]["description"], | |
"character_name": character["name"], | |
"character_description": character["description"], | |
"start": world["start"], | |
"inventory": inventory | |
or { | |
"cloth pants": 1, | |
"cloth shirt": 1, | |
"goggles": 1, | |
"leather bound journal": 1, | |
"gold": 5, | |
}, | |
"player": None, | |
"dungeon": None, | |
"current_quest": None, | |
"completed_quests": [], | |
"exp": 0, | |
"level": 1, | |
"reputation": {"Valdor": 0, "Ravenhurst": 0}, | |
} | |
# tprint(f"game_state in get_game_state: {game_state}") | |
# Extract required data with fallbacks | |
return game_state | |
except (FileNotFoundError, KeyError, json.JSONDecodeError) as e: | |
logger.error(f"Error loading world data: {e}") | |
# Provide default values if world loading fails | |
return { | |
"world": "Ethoria is a realm of seven kingdoms, each founded on distinct moral principles.", | |
"kingdom": "Valdor, the Kingdom of Courage", | |
"town": "Ravenhurst, a town of skilled hunters and trappers", | |
"character_name": "Elara Brightshield", | |
"character_description": "A sturdy warrior with shining silver armor", | |
"start": "Your journey begins in the mystical realm of Ethoria...", | |
"inventory": inventory | |
or { | |
"cloth pants": 1, | |
"cloth shirt": 1, | |
"goggles": 1, | |
"leather bound journal": 1, | |
"gold": 5, | |
}, | |
"player": None, | |
"dungeon": None, | |
"current_quest": None, | |
"completed_quests": [], | |
"exp": 0, | |
"level": 1, | |
"reputation": {"Valdor": 0, "Ravenhurst": 0}, | |
} | |
def generate_dynamic_quest(game_state: Dict) -> Dict: | |
"""Generate varied quests based on progress and level""" | |
completed = len(game_state.get("completed_quests", [])) | |
level = game_state.get("level", 1) | |
# Quest templates by type | |
quest_types = { | |
"combat": [ | |
{ | |
"title": "The Beast's Lair", | |
"description": "A fearsome {creature} has been terrorizing the outskirts of Ravenhurst.", | |
"objective": "Hunt down and defeat the {creature}.", | |
"creatures": [ | |
"shadow wolf", | |
"frost bear", | |
"ancient wyrm", | |
"spectral tiger", | |
], | |
}, | |
], | |
"exploration": [ | |
{ | |
"title": "Lost Secrets", | |
"description": "Rumors speak of an ancient {location} containing powerful artifacts.", | |
"objective": "Explore the {location} and uncover its secrets.", | |
"locations": [ | |
"crypt", | |
"temple ruins", | |
"abandoned mine", | |
"forgotten library", | |
], | |
}, | |
], | |
"mystery": [ | |
{ | |
"title": "Dark Omens", | |
"description": "The {sign} has appeared, marking the rise of an ancient power.", | |
"objective": "Investigate the meaning of the {sign}.", | |
"signs": [ | |
"blood moon", | |
"mysterious runes", | |
"spectral lights", | |
"corrupted wildlife", | |
], | |
}, | |
], | |
} | |
# Select quest type and template | |
quest_type = list(quest_types.keys())[completed % len(quest_types)] | |
template = quest_types[quest_type][0] # Could add more templates per type | |
# Fill in dynamic elements | |
if quest_type == "combat": | |
creature = template["creatures"][level % len(template["creatures"])] | |
title = template["title"] | |
description = template["description"].format(creature=creature) | |
objective = template["objective"].format(creature=creature) | |
elif quest_type == "exploration": | |
location = template["locations"][level % len(template["locations"])] | |
title = template["title"] | |
description = template["description"].format(location=location) | |
objective = template["objective"].format(location=location) | |
else: # mystery | |
sign = template["signs"][level % len(template["signs"])] | |
title = template["title"] | |
description = template["description"].format(sign=sign) | |
objective = template["objective"].format(sign=sign) | |
return { | |
"id": f"quest_{quest_type}_{completed}", | |
"title": title, | |
"description": f"{description} {objective}", | |
"exp_reward": 150 + (level * 50), | |
"status": "active", | |
"triggers": ["investigate", "explore", quest_type, "search"], | |
"completion_text": f"You've made progress in understanding the growing darkness.", | |
"next_quest_hint": "More mysteries await in the shadows of Ravenhurst.", | |
} | |
def generate_next_quest(game_state: Dict) -> Dict: | |
"""Generate next quest based on progress""" | |
completed = len(game_state.get("completed_quests", [])) | |
level = game_state.get("level", 1) | |
quest_chain = [ | |
{ | |
"id": "mist_investigation", | |
"title": "Investigate the Mist", | |
"description": "Strange mists have been gathering around Ravenhurst. Investigate their source.", | |
"exp_reward": 100, | |
"status": "active", | |
"triggers": ["mist", "investigate", "explore"], | |
"completion_text": "As you investigate the mist, you discover ancient runes etched into nearby stones.", | |
"next_quest_hint": "The runes seem to point to an old hunting trail.", | |
}, | |
{ | |
"id": "hunters_trail", | |
"title": "The Hunter's Trail", | |
"description": "Local hunters have discovered strange tracks in the forest. Follow them to their source.", | |
"exp_reward": 150, | |
"status": "active", | |
"triggers": ["tracks", "follow", "trail"], | |
"completion_text": "The tracks lead to an ancient well, where you hear strange whispers.", | |
"next_quest_hint": "The whispers seem to be coming from deep within the well.", | |
}, | |
{ | |
"id": "dark_whispers", | |
"title": "Whispers in the Dark", | |
"description": "Mysterious whispers echo from the old well. Investigate their source.", | |
"exp_reward": 200, | |
"status": "active", | |
"triggers": ["well", "whispers", "listen"], | |
"completion_text": "You discover an ancient seal at the bottom of the well.", | |
"next_quest_hint": "The seal bears markings of an ancient evil.", | |
}, | |
] | |
# Generate dynamic quests after initial chain | |
if completed >= len(quest_chain): | |
return generate_dynamic_quest(game_state) | |
# current_quest_index = min(completed, len(quest_chain) - 1) | |
# return quest_chain[current_quest_index] | |
return quest_chain[completed] | |
def check_quest_completion(message: str, game_state: Dict) -> Tuple[bool, str]: | |
"""Check quest completion and handle progression""" | |
if not game_state.get("current_quest"): | |
return False, "" | |
quest = game_state["current_quest"] | |
triggers = quest.get("triggers", []) | |
if any(trigger in message.lower() for trigger in triggers): | |
# Award experience | |
exp_reward = quest.get("exp_reward", 100) | |
game_state["exp"] += exp_reward | |
# Update player level if needed | |
while game_state["exp"] >= 100 * game_state["level"]: | |
game_state["level"] += 1 | |
game_state["player"].level = ( | |
game_state["level"] if game_state.get("player") else game_state["level"] | |
) | |
level_up_text = ( | |
f"\nLevel Up! You are now level {game_state['level']}!" | |
if game_state["exp"] >= 100 * (game_state["level"] - 1) | |
else "" | |
) | |
# Store completed quest | |
game_state["completed_quests"].append(quest) | |
# Generate next quest | |
next_quest = generate_next_quest(game_state) | |
game_state["current_quest"] = next_quest | |
# Update status display | |
if game_state.get("player"): | |
game_state["player"].exp = game_state["exp"] | |
game_state["player"].level = game_state["level"] | |
# Build completion message | |
completion_msg = f""" | |
Quest Complete: {quest['title']}! (+{exp_reward} exp){level_up_text} | |
{quest.get('completion_text', '')} | |
New Quest: {next_quest['title']} | |
{next_quest['description']} | |
{next_quest.get('next_quest_hint', '')}""" | |
return True, completion_msg | |
return False, "" | |
def parse_items_from_story(text: str) -> Dict[str, int]: | |
"""Extract item changes from story text with improved pattern matching""" | |
items = {} | |
# Skip parsing if text starts with common narrative phrases | |
skip_patterns = [ | |
"you see", | |
"you find yourself", | |
"you are", | |
"you stand", | |
"you hear", | |
"you feel", | |
] | |
if any(text.lower().startswith(pattern) for pattern in skip_patterns): | |
return items | |
# Common item keywords and patterns | |
gold_pattern = r"(\d+)\s*gold(?:\s+coins?)?" | |
items_pattern = r"(?:receive|find|given|obtain|pick up|grab)\s+(?:a|an|the)?\s*(\d+)?\s*([\w\s]+?)" | |
try: | |
# Find gold amounts | |
gold_matches = re.findall(gold_pattern, text.lower()) | |
if gold_matches: | |
items["gold"] = sum(int(x) for x in gold_matches) | |
# Find other items | |
item_matches = re.findall(items_pattern, text.lower()) | |
for count, item in item_matches: | |
# Validate item name | |
item = item.strip() | |
if len(item) > 2 and not any( # Minimum length check | |
skip in item for skip in ["yourself", "you", "door", "wall", "floor"] | |
): # Skip common words | |
count = int(count) if count else 1 | |
if item in items: | |
items[item] += count | |
else: | |
items[item] = count | |
return items | |
except Exception as e: | |
logger.error(f"Error parsing items from story: {e}") | |
return {} | |
def update_game_inventory(game_state: Dict, story_text: str) -> Tuple[str, list]: | |
"""Update inventory and return message and updated inventory data""" | |
try: | |
items = parse_items_from_story(story_text) | |
update_msg = "" | |
# Update inventory | |
for item, count in items.items(): | |
if item in game_state["inventory"]: | |
game_state["inventory"][item] += count | |
else: | |
game_state["inventory"][item] = count | |
update_msg += f"\nReceived: {count} {item}" | |
# Create updated inventory data for display | |
inventory_data = [ | |
[item, count] for item, count in game_state["inventory"].items() | |
] | |
return update_msg, inventory_data | |
except Exception as e: | |
logger.error(f"Error updating inventory: {e}") | |
return "", [] | |
def extract_response_after_action(full_text: str, action: str) -> str: | |
"""Extract response text that comes after the user action line""" | |
try: | |
if not full_text: # Add null check | |
logger.error("Received empty response from model") | |
return "You look around carefully." | |
# Split into lines | |
lines = full_text.split("\n") | |
# Find index of line containing user action | |
action_line_index = -1 | |
for i, line in enumerate(lines): | |
if action.lower() in line.lower(): # More flexible matching | |
action_line_index = i | |
break | |
if action_line_index >= 0: | |
# Get all lines after the action line | |
response_lines = lines[action_line_index + 1 :] | |
response = " ".join(line.strip() for line in response_lines if line.strip()) | |
# Clean up any remaining markers | |
response = response.split("user:")[0].strip() | |
response = response.split("system:")[0].strip() | |
response = response.split("assistant:")[0].strip() | |
return response if response else "You look around carefully." | |
return "You look around carefully." # Default response | |
except Exception as e: | |
logger.error(f"Error extracting response: {e}") | |
return "You look around carefully." | |
def run_action(message: str, history: list, game_state: Dict) -> str: | |
"""Process game actions and generate responses with quest handling""" | |
try: | |
initial_quest = generate_next_quest(game_state) | |
game_state["current_quest"] = initial_quest | |
# Handle start game command | |
if message.lower() == "start game": | |
start_response = f"""Welcome to {game_state['name']}. {game_state['world']} | |
{game_state['start']} | |
You are currently in {game_state['town_name']}, {game_state['town']}. | |
{game_state['town_name']} is a city in {game_state['kingdom']}. | |
Current Quest: {initial_quest['title']} | |
{initial_quest['description']} | |
What would you like to do?""" | |
return start_response | |
# Verify game state | |
if not isinstance(game_state, dict): | |
logger.error(f"Invalid game state type: {type(game_state)}") | |
return "Error: Invalid game state" | |
# Safety check with Prompt Guard | |
if not is_prompt_safe(message): | |
logger.warning("Unsafe content detected in user prompt") | |
return "I cannot process that request for safety reasons." | |
# logger.info(f"Processing action with game state: {game_state}") | |
logger.info(f"Processing action with game state") | |
world_info = f"""World: {game_state['world']} | |
Kingdom: {game_state['kingdom']} | |
Town: {game_state['town']} | |
Character: {game_state['character_name']} | |
Current Quest: {game_state["current_quest"]['title']} | |
Quest Objective: {game_state["current_quest"]['description']} | |
Inventory: {json.dumps(game_state['inventory'])}""" | |
# # Enhanced system prompt for better response formatting | |
# enhanced_prompt = f"""{system_prompt} | |
# Additional Rules: | |
# - Always start responses with 'You ', 'You see' or 'You hear' or 'You feel' | |
# - Use ONLY second person perspective ('you', not 'Elara' or 'she/he') | |
# - Describe immediate surroundings and sensations | |
# - Keep responses focused on the player's direct experience""" | |
# messages = [ | |
# {"role": "system", "content": system_prompt}, | |
# {"role": "user", "content": world_info}, | |
# ] | |
# Properly formatted messages for API | |
messages = [ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": world_info}, | |
{ | |
"role": "assistant", | |
"content": "I understand the game world and will help guide your adventure.", | |
}, | |
{"role": "user", "content": message}, | |
] | |
# # Format chat history | |
# if history: | |
# for h in history: | |
# if isinstance(h, tuple): | |
# messages.append({"role": "assistant", "content": h[0]}) | |
# messages.append({"role": "user", "content": h[1]}) | |
# Add history in correct alternating format | |
if history: | |
# for h in history[-3:]: # Last 3 exchanges | |
for h in history: | |
if isinstance(h, tuple): | |
messages.append({"role": "user", "content": h[0]}) | |
messages.append({"role": "assistant", "content": h[1]}) | |
# messages.append({"role": "user", "content": message}) | |
# Convert messages to string format for pipeline | |
prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) | |
logger.info("Generating response...") | |
## Generate response | |
# model_output = generator( | |
# prompt, | |
# max_new_tokens=len(tokenizer.encode(message)) | |
# + 120, # Set max_new_tokens based on input length | |
# num_return_sequences=1, | |
# # temperature=0.7, # More creative but still focused | |
# repetition_penalty=1.2, | |
# pad_token_id=tokenizer.eos_token_id, | |
# ) | |
# # Check for None response | |
# if not model_output or not isinstance(model_output, list): | |
# logger.error(f"Invalid model output: {model_output}") | |
# tprint(f"Invalid model output: {model_output}") | |
# return "You look around carefully." | |
# if not model_output[0] or not isinstance(model_output[0], dict): | |
# logger.error(f"Invalid response format: {type(model_output[0])}") | |
# return "You look around carefully." | |
# # Extract and clean response | |
# full_response = model_output[0]["generated_text"] | |
# if not full_response: | |
# logger.error("Empty response from model") | |
# return "You look around carefully." | |
# tprint(f"Full response in run_action: {full_response}") | |
# response = extract_response_after_action(full_response, message) | |
# tprint(f"Extracted response in run_action: {response}") | |
# # Convert to second person | |
# response = response.replace("Elara", "You") | |
# # # Format response | |
# # if not response.startswith("You"): | |
# # response = "You see " + response | |
# # Validate no cut-off sentences | |
# if response.rstrip().endswith(("you also", "meanwhile", "suddenly", "...")): | |
# response = response.rsplit(" ", 1)[0] # Remove last word | |
# # Ensure proper formatting | |
# response = response.rstrip("?").rstrip(".") + "." | |
# response = response.replace("...", ".") | |
# Initialize client and make API call | |
# client = initialize_inference_client() | |
client = inference_client | |
# Generate response using Inference API | |
completion = client.chat.completions.create( | |
model="mistralai/Mistral-7B-Instruct-v0.3", # Use inference API model | |
messages=messages, | |
max_tokens=520, | |
) | |
response = completion.choices[0].message.content | |
tprint(f"Generated response Inference API: {response}") | |
if not response: | |
return "You look around carefully." | |
# Safety check the responce using inference API | |
if not is_safe(response): | |
logger.warning("Unsafe content detected - blocking response") | |
return "This response was blocked for safety reasons." | |
# # Perform safety check before returning | |
# safe = is_safe(response) | |
# tprint(f"\nSafety Check Result: {'SAFE' if safe else 'UNSAFE'}") | |
# logger.info(f"Safety check result: {'SAFE' if safe else 'UNSAFE'}") | |
# if not safe: | |
# logging.warning("Unsafe content detected - blocking response") | |
# tprint("Unsafe content detected - Response blocked") | |
# return "This response was blocked for safety reasons." | |
# if safe: | |
# # Check for quest completion | |
# quest_completed, quest_message = check_quest_completion(message, game_state) | |
# if quest_completed: | |
# response += quest_message | |
# # Check for item updates | |
# inventory_update = update_game_inventory(game_state, response) | |
# if inventory_update: | |
# response += inventory_update | |
# Check for quest completion | |
quest_completed, quest_message = check_quest_completion(message, game_state) | |
if quest_completed: | |
response += quest_message | |
# Check for item-inventory updates | |
inventory_update, inventory_data = update_game_inventory(game_state, response) | |
if inventory_update: | |
response += inventory_update | |
tprint(f"Final response in run_action: {response}") | |
# Validate response | |
return response if response else "You look around carefully." | |
except KeyError as e: | |
logger.error(f"Missing required game state key: {e}") | |
return "Error: Game state is missing required information" | |
except Exception as e: | |
logger.error(f"Error generating response: {e}") | |
return ( | |
"I apologize, but I had trouble processing that command. Please try again." | |
) | |
def update_game_status(game_state: Dict) -> Tuple[str, str]: | |
"""Generate updated status and quest display text""" | |
# Status text | |
status_text = ( | |
f"Health: {game_state.get('player').health if game_state.get('player') else 100}/100\n" | |
f"Level: {game_state.get('level', 1)}\n" | |
f"Exp: {game_state.get('exp', 0)}/{100 * game_state.get('level', 1)}" | |
) | |
# Quest text | |
quest_text = "No active quest" | |
if game_state.get("current_quest"): | |
quest = game_state["current_quest"] | |
quest_text = f"{quest['title']}\n{quest['description']}" | |
if quest.get("next_quest_hint"): | |
quest_text += f"\n{quest['next_quest_hint']}" | |
return status_text, quest_text | |
def chat_response(message: str, chat_history: list, current_state: dict) -> tuple: | |
"""Process chat input and return response with updates""" | |
try: | |
if not message.strip(): | |
return chat_history, current_state, "", "", [] # Add empty inventory data | |
# Get AI response | |
output = run_action(message, chat_history, current_state) | |
# Update chat history without status info | |
chat_history = chat_history or [] | |
chat_history.append((message, output)) | |
# Update status displays | |
status_text, quest_text = update_game_status(current_state) | |
# Get inventory updates | |
update_msg, inventory_data = update_game_inventory(current_state, output) | |
if update_msg: | |
output += update_msg | |
# Return tuple includes empty string to clear input | |
return chat_history, current_state, status_text, quest_text, inventory_data | |
except Exception as e: | |
logger.error(f"Error in chat response: {e}") | |
return chat_history, current_state, "", "", [] | |
def start_game(main_loop, game_state, share=False): | |
"""Initialize and launch game interface""" | |
with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
gr.Markdown("# AI Dungeon Adventure") | |
# Game state storage | |
state = gr.State(game_state) | |
history = gr.State([]) | |
with gr.Row(): | |
# Game display | |
with gr.Column(scale=3): | |
chatbot = gr.Chatbot( | |
height=550, | |
placeholder="Type 'start game' to begin", | |
) | |
# Input area with submit button | |
with gr.Row(): | |
txt = gr.Textbox( | |
show_label=False, | |
placeholder="What do you want to do?", | |
container=False, | |
) | |
submit_btn = gr.Button("Submit", variant="primary") | |
clear = gr.ClearButton([txt, chatbot]) | |
# Enhanced Status panel | |
with gr.Column(scale=1): | |
with gr.Group(): | |
gr.Markdown("### Character Status") | |
status = gr.Textbox( | |
label="Status", | |
value="Health: 100/100\nLevel: 1\nExp: 0/100", | |
interactive=False, | |
) | |
quest_display = gr.Textbox( | |
label="Current Quest", | |
value="No active quest", | |
interactive=False, | |
) | |
inventory_data = [ | |
[item, count] | |
for item, count in game_state.get("inventory", {}).items() | |
] | |
inventory = gr.Dataframe( | |
value=inventory_data, | |
headers=["Item", "Quantity"], | |
label="Inventory", | |
interactive=False, | |
) | |
# Command suggestions | |
gr.Examples( | |
examples=[ | |
"look around", | |
"continue the story", | |
"take sword", | |
"go to the forest", | |
], | |
inputs=txt, | |
) | |
# def chat_response( | |
# message: str, chat_history: list, current_state: dict | |
# ) -> tuple: | |
# """Process chat input and return response with updates""" | |
# try: | |
# if not message.strip(): | |
# return chat_history, current_state, "" # Only clear input | |
# # Get AI response | |
# output = run_action(message, chat_history, current_state) | |
# # Update chat history | |
# chat_history = chat_history or [] | |
# chat_history.append((message, output)) | |
# # Update status if player exists | |
# # Update displays | |
# status_text = ( | |
# f"Health: {current_state['player'].health}/{current_state['player'].max_health}\n" | |
# f"Level: {current_state['player'].level}\n" | |
# f"Exp: {current_state['player'].exp}/{current_state['player'].exp_to_level}" | |
# ) | |
# quest_text = "No active quest" | |
# if current_state["current_quest"]: | |
# quest = current_state["current_quest"] | |
# quest_text = f"{quest['title']}\n{quest['description']}" | |
# # Update inventory display | |
# inv_data = [ | |
# [item, count] | |
# for item, count in current_state.get("inventory", {}).items() | |
# ] | |
# return (chat_history, current_state, status_text, quest_text) | |
# except Exception as e: | |
# logger.error(f"Error in chat response: {e}") | |
# return ( | |
# chat_history + [(message, "Error processing command")], | |
# current_state, | |
# status.value, | |
# quest_display.value, | |
# ) | |
# # Connect components | |
# submit_btn.click( | |
# chat_response, | |
# inputs=[txt, chatbot, state], | |
# # outputs=[chatbot, state, txt], | |
# outputs=[chatbot, state, status, quest_display], | |
# show_progress=True, | |
# ) | |
# # txt.submit( | |
# # chat_response, | |
# # [txt, chatbot, state], | |
# # [chatbot, state, txt], | |
# # show_progress=True, | |
# # ) | |
# txt.submit( | |
# chat_response, | |
# inputs=[txt, chatbot, state], | |
# outputs=[chatbot, state, status, quest_display], # Keep status separate | |
# show_progress=True, | |
# ) | |
def submit_action(message, history, state): | |
# Process response | |
new_history, new_state, status_text, quest_text, inventory_data = ( | |
chat_response(message, history, state) | |
) | |
# Update inventory display | |
inventory.value = inventory_data | |
# Clear input | |
return "", new_history, new_state, status_text, quest_text, inventory | |
submit_btn.click( | |
submit_action, | |
inputs=[txt, chatbot, state], | |
outputs=[txt, chatbot, state, status, quest_display, inventory], | |
) | |
txt.submit( | |
submit_action, | |
inputs=[txt, chatbot, state], | |
outputs=[txt, chatbot, state, status, quest_display, inventory], | |
) | |
demo.launch(share=share) | |
everyone_content_policy = { | |
"policy": """O1: Violence and Hate. | |
Should not | |
- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics. | |
- Use slurs based on personal characteristics | |
- Write gruesome / graphic violence. | |
Can | |
- Write tame fantasy violence | |
O2: Sexual Content. | |
Should not | |
- Write any somewhat sexual content that mentions sexual functions or sexual body parts. | |
- Write any content that would be inappropriate for kids. | |
Can | |
- Write romantic content (including kissing, flirting etc...) | |
O3: Self-Harm. | |
Should not | |
- Encourage people to harm themselves. | |
- Romanticize or condone self-harm or suicide. | |
- Write story depicting suicide. | |
O4: Profanity. | |
Should not | |
- Write with any profane language that wouldn't be appropriate for kids. | |
Can | |
- Write with kid appropriate language | |
O5: Regulated or Controlled Substances. | |
Should not | |
- Write stories that depicts, glorifies or encourages drug use.""" | |
} | |
def initialize_safety_client(): | |
"""Initialize HuggingFace Inference Client""" | |
try: | |
inference_key = get_huggingface_inference_key() | |
# api_key = get_huggingface_api_key() | |
return InferenceClient(api_key=inference_key) | |
except Exception as e: | |
logger.error(f"Failed to initialize safety client: {e}") | |
raise | |
# Initialize safety model pipeline | |
try: | |
safety_client = initialize_safety_client() | |
except Exception as e: | |
logger.error(f"Failed to initialize model: {str(e)}") | |
# Fallback to CPU if GPU initialization fails | |
def is_safe(message: str) -> bool: | |
"""Check content safety using Inference API""" | |
try: | |
# client = initialize_safety_client() | |
messages = [ | |
{"role": "user", "content": f"Check if this content is safe:\n{message}"}, | |
{ | |
"role": "assistant", | |
"content": f"I will check if the content is safe based on this content policy:\n{everyone_content_policy['policy']}", | |
}, | |
{"role": "user", "content": "Is it safe or unsafe?"}, | |
] | |
try: | |
completion = safety_client.chat.completions.create( | |
model=MODEL_CONFIG["safety_model"]["name"], | |
messages=messages, | |
max_tokens=MODEL_CONFIG["safety_model"]["max_tokens"], | |
temperature=0.1, | |
) | |
response = completion.choices[0].message.content.lower() | |
logger.info(f"Safety check response: {response}") | |
is_safe = "safe" in response and "unsafe" not in response | |
logger.info(f"Safety check result: {'SAFE' if is_safe else 'UNSAFE'}") | |
return is_safe | |
except Exception as api_error: | |
logger.error(f"API error: {api_error}") | |
# Fallback to allow common game commands | |
return any( | |
cmd in message.lower() for cmd in PROMPT_GUARD_CONFIG["safe_commands"] | |
) | |
except Exception as e: | |
logger.error(f"Safety check failed: {e}") | |
return False | |
# def init_safety_model(model_name, force_cpu=False): | |
# """Initialize safety checking model with optimized memory usage""" | |
# try: | |
# if force_cpu: | |
# device = -1 | |
# else: | |
# device = MODEL_CONFIG["safety_model"]["device"] | |
# # model_id = "meta-llama/Llama-Guard-3-8B" | |
# # model_id = "meta-llama/Llama-Guard-3-1B" | |
# api_key = get_huggingface_api_key() | |
# safety_model = AutoModelForCausalLM.from_pretrained( | |
# model_name, | |
# token=api_key, | |
# torch_dtype=MODEL_CONFIG["safety_model"]["dtype"], | |
# use_cache=True, | |
# device_map="auto", | |
# ) | |
# safety_model.config.use_cache = True | |
# safety_tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key) | |
# # Set pad token explicitly | |
# safety_tokenizer.pad_token = safety_tokenizer.eos_token | |
# logger.info(f"Safety model initialized successfully on {device}") | |
# return safety_model, safety_tokenizer | |
# except Exception as e: | |
# logger.error(f"Failed to initialize safety model: {e}") | |
# raise | |
# # Initialize safety model pipeline | |
# try: | |
# safety_model_name = MODEL_CONFIG["safety_model"]["name"] | |
# api_key = get_huggingface_api_key() | |
# # Initialize the pipeline with memory management | |
# safety_model, safety_tokenizer = init_safety_model(safety_model_name) | |
# except Exception as e: | |
# logger.error(f"Failed to initialize model: {str(e)}") | |
# # Fallback to CPU if GPU initialization fails | |
# try: | |
# logger.info("Attempting CPU fallback...") | |
# safety_model, safety_tokenizer = init_safety_model( | |
# safety_model_name, force_cpu=True | |
# ) | |
# except Exception as e: | |
# logger.error(f"CPU fallback failed: {str(e)}") | |
# raise | |
# @lru_cache(maxsize=1000) | |
# def get_safety_response(prompt: str) -> str: | |
# """Cached safety check responses""" | |
# with torch.no_grad(): | |
# inputs = safety_tokenizer( | |
# prompt, | |
# return_tensors="pt", | |
# padding=True, | |
# truncation=True, | |
# max_length=MODEL_CONFIG["safety_model"]["max_length"], | |
# ).to(safety_model.device) | |
# output = safety_model.generate( | |
# **inputs, | |
# max_new_tokens=5, | |
# temperature=0.1, # Lower temperature for more consistent output | |
# pad_token_id=safety_tokenizer.pad_token_id, | |
# eos_token_id=safety_tokenizer.eos_token_id, | |
# ) | |
# return safety_tokenizer.decode(output[0], skip_special_tokens=True) | |
# def is_safe(message: str) -> bool: | |
# """Check if message content is safe""" | |
# try: | |
# logger.info("Performing safety check...") | |
# # Build the prompt with embedded values | |
# # prompt = f"""[INST] Task: Check if this content is safe according to our safety policy with the below categories. | |
# # <BEGIN UNSAFE CONTENT CATEGORIES> | |
# # {everyone_content_policy['policy']} | |
# # <END UNSAFE CONTENT CATEGORIES> | |
# # <CONTENT TO CHECK> | |
# # {message} | |
# # <END CONTENT TO CHECK> | |
# # Provide your safety assessment for user in the above CONTENT TO CHECK section: | |
# # - Reply ONLY with single word 'safe' or 'unsafe' [/INST] | |
# # """ | |
# prompt = f"""[INST]Check if safe: | |
# {message} | |
# Reply 'safe' or 'unsafe'[/INST] | |
# """ | |
# # Generate safety check on cuda | |
# # with torch.no_grad(): | |
# # inputs = safety_tokenizer( | |
# # prompt, | |
# # return_tensors="pt", | |
# # padding=True, | |
# # truncation=True, | |
# # ) | |
# # # Move inputs to correct device | |
# # inputs = {k: v.to(device) for k, v in inputs.items()} | |
# # output = safety_model.generate( | |
# # **inputs, | |
# # max_new_tokens=10, | |
# # temperature=0.1, # Lower temperature for more consistent output | |
# # pad_token_id=safety_tokenizer.pad_token_id, # Use configured pad token | |
# # eos_token_id=safety_tokenizer.eos_token_id, | |
# # do_sample=False, | |
# # ) | |
# # result = safety_tokenizer.decode(output[0], skip_special_tokens=True) | |
# result = get_safety_response(prompt) | |
# tprint(f"Raw safety check result: {result}") | |
# # # Extract response after prompt | |
# # if "[/INST]" in result: | |
# # result = result.split("[/INST]")[-1] | |
# # # Clean response | |
# # result = result.lower().strip() | |
# # tprint(f"Cleaned safety check result: {result}") | |
# # words = [word for word in result.split() if word in ["safe", "unsafe"]] | |
# # # Take first valid response word | |
# # is_safe = words[0] == "safe" if words else False | |
# # tprint("Final Safety check result:", is_safe) | |
# is_safe = "safe" in result.lower().split() | |
# logger.info( | |
# f"Safety check completed - Result: {'SAFE' if is_safe else 'UNSAFE'}" | |
# ) | |
# return is_safe | |
# except Exception as e: | |
# logger.error(f"Safety check failed: {e}") | |
# return False | |
# def detect_inventory_changes(game_state, output): | |
# inventory = game_state["inventory"] | |
# messages = [ | |
# {"role": "system", "content": system_prompt}, | |
# {"role": "user", "content": f"Current Inventory: {str(inventory)}"}, | |
# {"role": "user", "content": f"Recent Story: {output}"}, | |
# {"role": "user", "content": "Inventory Updates"}, | |
# ] | |
# input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) | |
# model_output = generator(input_text, num_return_sequences=1, temperature=0.0) | |
# response = model_output[0]["generated_text"] | |
# result = json.loads(response) | |
# return result["itemUpdates"] | |
# def update_inventory(inventory, item_updates): | |
# update_msg = "" | |
# for update in item_updates: | |
# name = update["name"] | |
# change_amount = update["change_amount"] | |
# if change_amount > 0: | |
# if name not in inventory: | |
# inventory[name] = change_amount | |
# else: | |
# inventory[name] += change_amount | |
# update_msg += f"\nInventory: {name} +{change_amount}" | |
# elif name in inventory and change_amount < 0: | |
# inventory[name] += change_amount | |
# update_msg += f"\nInventory: {name} {change_amount}" | |
# if name in inventory and inventory[name] < 0: | |
# del inventory[name] | |
# return update_msg | |
logging.info("Finished helper function") | |