prgrmc committed on
Commit 31e50da · 1 Parent(s): 82596d2

add inference API mistral

Files changed (2)
  1. README.md +1 -1
  2. helper.py +128 -67
README.md CHANGED
@@ -7,7 +7,7 @@ sdk: gradio
sdk_version: 5.9.1
app_file: main.py
pinned: false
- license: mit
+ license: -
---

# AI-Powered Dungeon Adventure Game
helper.py CHANGED
@@ -5,6 +5,7 @@ import json
import gradio as gr
import torch # first import torch then transformers

+ from huggingface_hub import InferenceClient
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
@@ -42,11 +43,23 @@ def get_huggingface_api_key():
    return huggingface_api_key


+ def get_huggingface_inference_key():
+     load_env()
+     huggingface_inference_key = os.getenv("HUGGINGFACE_INFERENCE_KEY")
+     if not huggingface_inference_key:
+         logging.error("HUGGINGFACE_INFERENCE_KEY not found in environment variables")
+         raise ValueError("HUGGINGFACE_INFERENCE_KEY not found in environment variables")
+     return huggingface_inference_key
+
+
# Model configuration
MODEL_CONFIG = {
    "main_model": {
        # "name": "meta-llama/Llama-3.2-3B-Instruct",
-         "name": "meta-llama/Llama-3.2-1B-Instruct", # to fit in cpu on hugging face space
+         # "name": "meta-llama/Llama-3.2-1B-Instruct", # to fit in cpu on hugging face space
+         "name": "meta-llama/Llama-3.2-1B", # to fit in cpu on hugging face space
+         # "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # to fit in cpu on hugging face space
+         # "name": "microsoft/phi-2",
        # "dtype": torch.bfloat16,
        "dtype": torch.float32, # Use float32 for CPU
        "max_length": 512,
@@ -110,31 +123,44 @@ def initialize_model_pipeline(model_name, force_cpu=False):
        raise


- # Initialize model pipeline
- try:
-     # Use a smaller model for testing
-     # model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
-     # model_name = "google/gemma-2-2b" # Start with a smaller model
-     # model_name = "microsoft/phi-2"
-     # model_name = "meta-llama/Llama-3.2-1B-Instruct"
-     # model_name = "meta-llama/Llama-3.2-3B-Instruct"
-
-     model_name = MODEL_CONFIG["main_model"]["name"]
-
-     # Initialize the pipeline with memory management
-     generator, tokenizer = initialize_model_pipeline(model_name)
-
- except Exception as e:
-     logger.error(f"Failed to initialize model: {str(e)}")
-     # Fallback to CPU if GPU initialization fails
+ def initialize_inference_client():
+     """Initialize HuggingFace Inference Client"""
    try:
-         logger.info("Attempting CPU fallback...")
-         generator, tokenizer = initialize_model_pipeline(model_name, force_cpu=True)
+         inference_key = get_huggingface_inference_key()
+
+         client = InferenceClient(api_key=inference_key)
+         logger.info("Inference Client initialized successfully")
+         return client
    except Exception as e:
-         logger.error(f"CPU fallback failed: {str(e)}")
+         logger.error(f"Failed to initialize Inference Client: {e}")
        raise


+ # # Initialize model pipeline
+ # try:
+ #     # Use a smaller model for testing
+ #     # model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+ #     # model_name = "google/gemma-2-2b" # Start with a smaller model
+ #     # model_name = "microsoft/phi-2"
+ #     # model_name = "meta-llama/Llama-3.2-1B-Instruct"
+ #     # model_name = "meta-llama/Llama-3.2-3B-Instruct"
+
+ #     model_name = MODEL_CONFIG["main_model"]["name"]
+
+ #     # Initialize the pipeline with memory management
+ #     generator, tokenizer = initialize_model_pipeline(model_name)
+
+ #     except Exception as e:
+ #     logger.error(f"Failed to initialize model: {str(e)}")
+ #     # Fallback to CPU if GPU initialization fails
+ #     try:
+ #         logger.info("Attempting CPU fallback...")
+ #         generator, tokenizer = initialize_model_pipeline(model_name, force_cpu=True)
+ #     except Exception as e:
+ #         logger.error(f"CPU fallback failed: {str(e)}")
+ #         raise
+
+
def load_world(filename):
    with open(filename, "r") as f:
        return json.load(f)
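
A hedged usage sketch of the new client path: `initialize_inference_client()` replaces the old `pipeline`-based generator, and a single chat completion is requested from the hosted model. The `ask_narrator` wrapper and its prompt are hypothetical; the model id, `max_tokens` usage, and response access mirror the call added later in `run_action`.

```python
# Illustrative only: exercises initialize_inference_client() from helper.py.
# The ask_narrator wrapper and the prompt text are hypothetical.
from helper import initialize_inference_client


def ask_narrator(prompt: str) -> str:
    client = initialize_inference_client()
    completion = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.3",  # same hosted model used in run_action
        messages=[{"role": "user", "content": prompt}],
        max_tokens=120,
    )
    return completion.choices[0].message.content or ""


if __name__ == "__main__":
    print(ask_narrator("Describe a ruined watchtower in two sentences."))
```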
@@ -494,12 +520,12 @@ def extract_response_after_action(full_text: str, action: str) -> str:
def run_action(message: str, history: list, game_state: Dict) -> str:
    """Process game actions and generate responses with quest handling"""
    try:
+         initial_quest = generate_next_quest(game_state)
+         game_state["current_quest"] = initial_quest
+
        # Handle start game command
        if message.lower() == "start game":

-             initial_quest = generate_next_quest(game_state)
-             game_state["current_quest"] = initial_quest
-
            start_response = f"""Welcome to {game_state['name']}. {game_state['world']}

{game_state['start']}
@@ -538,71 +564,105 @@ Inventory: {json.dumps(game_state['inventory'])}"""
        # - Describe immediate surroundings and sensations
        # - Keep responses focused on the player's direct experience"""

+         # messages = [
+         #     {"role": "system", "content": system_prompt},
+         #     {"role": "user", "content": world_info},
+         # ]
+
+         # Properly formatted messages for API
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": world_info},
+             {
+                 "role": "assistant",
+                 "content": "I understand the game world and will help guide your adventure.",
+             },
+             {"role": "user", "content": message},
        ]

-         # Format chat history
+         # # Format chat history
+         # if history:
+         #     for h in history:
+         #         if isinstance(h, tuple):
+         #             messages.append({"role": "assistant", "content": h[0]})
+         #             messages.append({"role": "user", "content": h[1]})
+
+         # Add history in correct alternating format
        if history:
-             for h in history:
+             for h in history[-3:]:  # Last 3 exchanges
                if isinstance(h, tuple):
-                     messages.append({"role": "assistant", "content": h[0]})
-                     messages.append({"role": "user", "content": h[1]})
+                     messages.append({"role": "user", "content": h[0]})
+                     messages.append({"role": "assistant", "content": h[1]})

-         messages.append({"role": "user", "content": message})
+         # messages.append({"role": "user", "content": message})

        # Convert messages to string format for pipeline
        prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])

        logger.info("Generating response...")
-         # Generate response
-         model_output = generator(
-             prompt,
-             max_new_tokens=len(tokenizer.encode(message))
-             + 120, # Set max_new_tokens based on input length
-             num_return_sequences=1,
-             # temperature=0.7, # More creative but still focused
-             repetition_penalty=1.2,
-             pad_token_id=tokenizer.eos_token_id,
-         )
-         # logger.info(f"Raw model output: {model_output}")
+         ## Generate response
+         # model_output = generator(
+         #     prompt,
+         #     max_new_tokens=len(tokenizer.encode(message))
+         #     + 120, # Set max_new_tokens based on input length
+         #     num_return_sequences=1,
+         #     # temperature=0.7, # More creative but still focused
+         #     repetition_penalty=1.2,
+         #     pad_token_id=tokenizer.eos_token_id,
+         # )

-         # Check for None response
-         if not model_output or not isinstance(model_output, list):
-             logger.error(f"Invalid model output: {model_output}")
-             print(f"Invalid model output: {model_output}")
-             return "You look around carefully."
+         # # Check for None response
+         # if not model_output or not isinstance(model_output, list):
+         #     logger.error(f"Invalid model output: {model_output}")
+         #     print(f"Invalid model output: {model_output}")
+         #     return "You look around carefully."

-         if not model_output[0] or not isinstance(model_output[0], dict):
-             logger.error(f"Invalid response format: {type(model_output[0])}")
-             return "You look around carefully."
+         # if not model_output[0] or not isinstance(model_output[0], dict):
+         #     logger.error(f"Invalid response format: {type(model_output[0])}")
+         #     return "You look around carefully."

-         # Extract and clean response
-         full_response = model_output[0]["generated_text"]
-         if not full_response:
-             logger.error("Empty response from model")
-             return "You look around carefully."
+         # # Extract and clean response
+         # full_response = model_output[0]["generated_text"]
+         # if not full_response:
+         #     logger.error("Empty response from model")
+         #     return "You look around carefully."
+

-         print(f"Full response in run_action: {full_response}")
+         # print(f"Full response in run_action: {full_response}")

-         response = extract_response_after_action(full_response, message)
-         print(f"Extracted response in run_action: {response}")
+         # response = extract_response_after_action(full_response, message)
+         # print(f"Extracted response in run_action: {response}")

-         # Convert to second person
-         response = response.replace("Elara", "You")
+         # # Convert to second person
+         # response = response.replace("Elara", "You")

-         # # Format response
-         # if not response.startswith("You"):
-         #     response = "You see " + response
+         # # # Format response
+         # # if not response.startswith("You"):
+         # #     response = "You see " + response

-         # Validate no cut-off sentences
-         if response.rstrip().endswith(("you also", "meanwhile", "suddenly", "...")):
-             response = response.rsplit(" ", 1)[0] # Remove last word
+         # # Validate no cut-off sentences
+         # if response.rstrip().endswith(("you also", "meanwhile", "suddenly", "...")):
+         #     response = response.rsplit(" ", 1)[0] # Remove last word

-         # Ensure proper formatting
-         response = response.rstrip("?").rstrip(".") + "."
-         response = response.replace("...", ".")
+         # # Ensure proper formatting
+         # response = response.rstrip("?").rstrip(".") + "."
+         # response = response.replace("...", ".")

+         # Initialize client and make API call
+         client = initialize_inference_client()

+         # Generate response using Inference API
+         completion = client.chat.completions.create(
+             model="mistralai/Mistral-7B-Instruct-v0.3", # Use inference API model
+             messages=messages,
+             max_tokens=520,
+         )
+
+         response = completion.choices[0].message.content
+
+         print(f"Generated response Inference API: {response}")
+
+         if not response:
+             return "You look around carefully."

        # # Perform safety check before returning
        # safe = is_safe(response)
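
The history handling above flattens Gradio's `(user, assistant)` tuples into strictly alternating roles and keeps only the last three exchanges; note that, as written in the diff, the current player message sits in the initial list and the history is appended after it. A small, self-contained sketch of that construction (the sample values are invented):

```python
from typing import Dict, List, Tuple


def build_messages(system_prompt: str, world_info: str,
                   history: List[Tuple[str, str]], message: str) -> List[Dict[str, str]]:
    """Mirror run_action's message construction for the chat-completion call."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": world_info},
        {"role": "assistant", "content": "I understand the game world and will help guide your adventure."},
        {"role": "user", "content": message},
    ]
    # Keep only the last three (user, assistant) exchanges, in alternating order.
    for user_turn, assistant_turn in history[-3:]:
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": assistant_turn})
    return messages


# Invented example values:
msgs = build_messages(
    system_prompt="You are the narrator of a text adventure game.",
    world_info="World: a drifting archipelago of storm islands.",
    history=[("look around", "You stand on a wind-blasted quay.")],
    message="open the door",
)
```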
@@ -635,6 +695,7 @@ Inventory: {json.dumps(game_state['inventory'])}"""
        if inventory_update:
            response += inventory_update

+         print(f"Final response in run_action: {response}")
        # Validate response
        return response if response else "You look around carefully."
