prgrmc commited on
Commit
ebb4490
·
1 Parent(s): fd206ac

add login function from huggingface_hub

Browse files
Files changed (1) hide show
  1. helper.py +14 -2
helper.py CHANGED
@@ -9,6 +9,8 @@ from transformers import AutoModelForSequenceClassification
9
  from huggingface_hub import InferenceClient
10
 
11
  from transformers import pipeline
 
 
12
  from transformers import AutoTokenizer, AutoModelForCausalLM
13
  import logging
14
  import psutil
@@ -39,6 +41,7 @@ def load_env():
39
  def get_huggingface_api_key():
40
  load_env()
41
  huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY")
 
42
  if not huggingface_api_key:
43
  logging.error("HUGGINGFACE_API_KEY not found in environment variables")
44
  raise ValueError("HUGGINGFACE_API_KEY not found in environment variables")
@@ -1131,10 +1134,19 @@ def initialize_safety_client():
1131
  raise
1132
 
1133
 
 
 
 
 
 
 
 
 
 
1134
  def is_safe(message: str) -> bool:
1135
  """Check content safety using Inference API"""
1136
  try:
1137
- client = initialize_safety_client()
1138
 
1139
  messages = [
1140
  {"role": "user", "content": f"Check if this content is safe:\n{message}"},
@@ -1146,7 +1158,7 @@ def is_safe(message: str) -> bool:
1146
  ]
1147
 
1148
  try:
1149
- completion = client.chat.completions.create(
1150
  model=MODEL_CONFIG["safety_model"]["name"],
1151
  messages=messages,
1152
  max_tokens=MODEL_CONFIG["safety_model"]["max_tokens"],
 
9
  from huggingface_hub import InferenceClient
10
 
11
  from transformers import pipeline
12
+ from huggingface_hub import login
13
+
14
  from transformers import AutoTokenizer, AutoModelForCausalLM
15
  import logging
16
  import psutil
 
41
  def get_huggingface_api_key():
42
  load_env()
43
  huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY")
44
+ login(token=huggingface_api_key)
45
  if not huggingface_api_key:
46
  logging.error("HUGGINGFACE_API_KEY not found in environment variables")
47
  raise ValueError("HUGGINGFACE_API_KEY not found in environment variables")
 
1134
  raise
1135
 
1136
 
1137
+ # Initialize safety model pipeline
1138
+ try:
1139
+ safety_client = initialize_safety_client()
1140
+
1141
+ except Exception as e:
1142
+ logger.error(f"Failed to initialize model: {str(e)}")
1143
+ # Fallback to CPU if GPU initialization fails
1144
+
1145
+
1146
  def is_safe(message: str) -> bool:
1147
  """Check content safety using Inference API"""
1148
  try:
1149
+ # client = initialize_safety_client()
1150
 
1151
  messages = [
1152
  {"role": "user", "content": f"Check if this content is safe:\n{message}"},
 
1158
  ]
1159
 
1160
  try:
1161
+ completion = safety_client.chat.completions.create(
1162
  model=MODEL_CONFIG["safety_model"]["name"],
1163
  messages=messages,
1164
  max_tokens=MODEL_CONFIG["safety_model"]["max_tokens"],