add inference API mistral
README.md
CHANGED
@@ -7,7 +7,7 @@ sdk: gradio
 sdk_version: 5.9.1
 app_file: main.py
 pinned: false
-license:
+license: -
 ---

 # AI-Powered Dungeon Adventure Game
helper.py
CHANGED
@@ -5,6 +5,7 @@ import json
 import gradio as gr
 import torch  # first import torch then transformers

+from huggingface_hub import InferenceClient
 from transformers import pipeline
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
@@ -42,11 +43,23 @@ def get_huggingface_api_key():
     return huggingface_api_key


+def get_huggingface_inference_key():
+    load_env()
+    huggingface_inference_key = os.getenv("HUGGINGFACE_INFERENCE_KEY")
+    if not huggingface_inference_key:
+        logging.error("HUGGINGFACE_API_KEY not found in environment variables")
+        raise ValueError("HUGGINGFACE_API_KEY not found in environment variables")
+    return huggingface_inference_key
+
+
 # Model configuration
 MODEL_CONFIG = {
     "main_model": {
         # "name": "meta-llama/Llama-3.2-3B-Instruct",
-        "name": "meta-llama/Llama-3.2-1B-Instruct",  # to fit in cpu on hugging face space
+        # "name": "meta-llama/Llama-3.2-1B-Instruct",  # to fit in cpu on hugging face space
+        "name": "meta-llama/Llama-3.2-1B",  # to fit in cpu on hugging face space
+        # "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # to fit in cpu on hugging face space
+        # "name": "microsoft/phi-2",
         # "dtype": torch.bfloat16,
         "dtype": torch.float32,  # Use float32 for CPU
         "max_length": 512,
@@ -110,31 +123,44 @@ def initialize_model_pipeline(model_name, force_cpu=False):
         raise


-# Initialize model pipeline
-try:
-    # Use a smaller model for testing
-    # model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
-    # model_name = "google/gemma-2-2b"  # Start with a smaller model
-    # model_name = "microsoft/phi-2"
-    # model_name = "meta-llama/Llama-3.2-1B-Instruct"
-    # model_name = "meta-llama/Llama-3.2-3B-Instruct"
-
-    model_name = MODEL_CONFIG["main_model"]["name"]
-
-    # Initialize the pipeline with memory management
-    generator, tokenizer = initialize_model_pipeline(model_name)
-
-except Exception as e:
-    logger.error(f"Failed to initialize model: {str(e)}")
-    # Fallback to CPU if GPU initialization fails
+def initialize_inference_client():
+    """Initialize HuggingFace Inference Client"""
     try:
-        logger.info("Attempting CPU fallback...")
-        generator, tokenizer = initialize_model_pipeline(model_name, force_cpu=True)
+        inference_key = get_huggingface_inference_key()
+
+        client = InferenceClient(api_key=inference_key)
+        logger.info("Inference Client initialized successfully")
+        return client
     except Exception as e:
-        logger.error(f"CPU fallback failed: {str(e)}")
+        logger.error(f"Failed to initialize Inference Client: {e}")
         raise


+# # Initialize model pipeline
+# try:
+#     # Use a smaller model for testing
+#     # model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+#     # model_name = "google/gemma-2-2b"  # Start with a smaller model
+#     # model_name = "microsoft/phi-2"
+#     # model_name = "meta-llama/Llama-3.2-1B-Instruct"
+#     # model_name = "meta-llama/Llama-3.2-3B-Instruct"
+
+#     model_name = MODEL_CONFIG["main_model"]["name"]
+
+#     # Initialize the pipeline with memory management
+#     generator, tokenizer = initialize_model_pipeline(model_name)
+
+# except Exception as e:
+#     logger.error(f"Failed to initialize model: {str(e)}")
+#     # Fallback to CPU if GPU initialization fails
+#     try:
+#         logger.info("Attempting CPU fallback...")
+#         generator, tokenizer = initialize_model_pipeline(model_name, force_cpu=True)
+#     except Exception as e:
+#         logger.error(f"CPU fallback failed: {str(e)}")
+#         raise
+
+
 def load_world(filename):
     with open(filename, "r") as f:
         return json.load(f)
@@ -494,12 +520,12 @@ def extract_response_after_action(full_text: str, action: str) -> str:
 def run_action(message: str, history: list, game_state: Dict) -> str:
     """Process game actions and generate responses with quest handling"""
     try:
+        initial_quest = generate_next_quest(game_state)
+        game_state["current_quest"] = initial_quest
+
         # Handle start game command
         if message.lower() == "start game":

-            initial_quest = generate_next_quest(game_state)
-            game_state["current_quest"] = initial_quest
-
             start_response = f"""Welcome to {game_state['name']}. {game_state['world']}

 {game_state['start']}
@@ -538,71 +564,105 @@ Inventory: {json.dumps(game_state['inventory'])}"""
         # - Describe immediate surroundings and sensations
         # - Keep responses focused on the player's direct experience"""

+        # messages = [
+        #     {"role": "system", "content": system_prompt},
+        #     {"role": "user", "content": world_info},
+        # ]
+
+        # Properly formatted messages for API
         messages = [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": world_info},
+            {
+                "role": "assistant",
+                "content": "I understand the game world and will help guide your adventure.",
+            },
+            {"role": "user", "content": message},
         ]

-        # Format chat history
+        # # Format chat history
+        # if history:
+        #     for h in history:
+        #         if isinstance(h, tuple):
+        #             messages.append({"role": "assistant", "content": h[0]})
+        #             messages.append({"role": "user", "content": h[1]})
+
+        # Add history in correct alternating format
         if history:
-            for h in history:
+            for h in history[-3:]:  # Last 3 exchanges
                 if isinstance(h, tuple):
-                    messages.append({"role": "assistant", "content": h[0]})
-                    messages.append({"role": "user", "content": h[1]})
+                    messages.append({"role": "user", "content": h[0]})
+                    messages.append({"role": "assistant", "content": h[1]})

-        messages.append({"role": "user", "content": message})
+        # messages.append({"role": "user", "content": message})

         # Convert messages to string format for pipeline
         prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])

         logger.info("Generating response...")
-        # Generate response
-        model_output = generator(
-            prompt,
-            max_new_tokens=len(tokenizer.encode(message))
-            + 120,  # Set max_new_tokens based on input length
-            num_return_sequences=1,
-            # temperature=0.7,  # More creative but still focused
-            repetition_penalty=1.2,
-            pad_token_id=tokenizer.eos_token_id,
-        )
-        # logger.info(f"Raw model output: {model_output}")
+        ## Generate response
+        # model_output = generator(
+        #     prompt,
+        #     max_new_tokens=len(tokenizer.encode(message))
+        #     + 120,  # Set max_new_tokens based on input length
+        #     num_return_sequences=1,
+        #     # temperature=0.7,  # More creative but still focused
+        #     repetition_penalty=1.2,
+        #     pad_token_id=tokenizer.eos_token_id,
+        # )

-        # Check for None response
-        if not model_output or not isinstance(model_output, list):
-            logger.error(f"Invalid model output: {model_output}")
-            print(f"Invalid model output: {model_output}")
-            return "You look around carefully."
+        # # Check for None response
+        # if not model_output or not isinstance(model_output, list):
+        #     logger.error(f"Invalid model output: {model_output}")
+        #     print(f"Invalid model output: {model_output}")
+        #     return "You look around carefully."

-        if not model_output[0] or not isinstance(model_output[0], dict):
-            logger.error(f"Invalid response format: {type(model_output[0])}")
-            return "You look around carefully."
+        # if not model_output[0] or not isinstance(model_output[0], dict):
+        #     logger.error(f"Invalid response format: {type(model_output[0])}")
+        #     return "You look around carefully."

-        # Extract and clean response
-        full_response = model_output[0]["generated_text"]
-        if not full_response:
-            logger.error("Empty response from model")
-            return "You look around carefully."
+        # # Extract and clean response
+        # full_response = model_output[0]["generated_text"]
+        # if not full_response:
+        #     logger.error("Empty response from model")
+        #     return "You look around carefully."
+
+        # print(f"Full response in run_action: {full_response}")
+
+        # response = extract_response_after_action(full_response, message)
+        # print(f"Extracted response in run_action: {response}")

-        print(f"Full response in run_action: {full_response}")
+        # # Convert to second person
+        # response = response.replace("Elara", "You")

-        response = extract_response_after_action(full_response, message)
-        print(f"Extracted response in run_action: {response}")
+        # # # Format response
+        # # if not response.startswith("You"):
+        # #     response = "You see " + response

-        # Convert to second person
-        response = response.replace("Elara", "You")
+        # # Validate no cut-off sentences
+        # if response.rstrip().endswith(("you also", "meanwhile", "suddenly", "...")):
+        #     response = response.rsplit(" ", 1)[0]  # Remove last word

-        # # Format response
-        # if not response.startswith("You"):
-        #     response = "You see " + response
+        # # Ensure proper formatting
+        # response = response.rstrip("?").rstrip(".") + "."
+        # response = response.replace("...", ".")

-        # Validate no cut-off sentences
-        if response.rstrip().endswith(("you also", "meanwhile", "suddenly", "...")):
-            response = response.rsplit(" ", 1)[0]  # Remove last word
+        # Initialize client and make API call
+        client = initialize_inference_client()

-        # Ensure proper formatting
-        response = response.rstrip("?").rstrip(".") + "."
-        response = response.replace("...", ".")
+        # Generate response using Inference API
+        completion = client.chat.completions.create(
+            model="mistralai/Mistral-7B-Instruct-v0.3",  # Use inference API model
+            messages=messages,
+            max_tokens=520,
+        )
+
+        response = completion.choices[0].message.content
+
+        print(f"Generated response Inference API: {response}")
+
+        if not response:
+            return "You look around carefully."

         # # Perform safety check before returning
         # safe = is_safe(response)
@@ -635,6 +695,7 @@ Inventory: {json.dumps(game_state['inventory'])}"""
         if inventory_update:
             response += inventory_update

+        print(f"Final response in run_action: {response}")
         # Validate response
         return response if response else "You look around carefully."

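For reference, a minimal sketch of how the new Inference API path can be exercised outside the Gradio app, assuming helper.py exposes the functions added in this commit and that HUGGINGFACE_INFERENCE_KEY is available in the environment (or a .env file). The example messages are illustrative only; in the app they are built from game_state and chat history.

# Minimal sketch: drive the new Inference API path directly.
# Assumes HUGGINGFACE_INFERENCE_KEY is set in the environment (or .env)
# and that helper.py exposes initialize_inference_client() as added above.
from helper import initialize_inference_client

client = initialize_inference_client()  # wraps InferenceClient(api_key=...)

completion = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    messages=[
        # Illustrative messages; run_action() builds these from game_state and history.
        {"role": "system", "content": "You are the narrator of a dungeon adventure."},
        {"role": "user", "content": "start game"},
    ],
    max_tokens=520,
)

print(completion.choices[0].message.content)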