add login function from huggingface_hub
helper.py (CHANGED)
@@ -9,6 +9,8 @@ from transformers import AutoModelForSequenceClassification
 from huggingface_hub import InferenceClient
 
 from transformers import pipeline
+from huggingface_hub import login
+
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 import psutil
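Review note: this leaves two separate huggingface_hub imports in the header. A minimal consolidation sketch, using only names already present in this hunk:

from huggingface_hub import InferenceClient, login
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import psutil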
@@ -39,6 +41,7 @@ def load_env():
 def get_huggingface_api_key():
     load_env()
     huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY")
+    login(token=huggingface_api_key)
     if not huggingface_api_key:
         logging.error("HUGGINGFACE_API_KEY not found in environment variables")
         raise ValueError("HUGGINGFACE_API_KEY not found in environment variables")
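Review note: the new login(token=huggingface_api_key) call runs before the "if not huggingface_api_key" guard, so when the variable is unset, login() receives None instead of the function raising its intended ValueError. A sketch with the call moved after the guard; the return statement is an assumption, since the rest of the function is outside this hunk:

def get_huggingface_api_key():
    load_env()
    huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY")
    if not huggingface_api_key:
        logging.error("HUGGINGFACE_API_KEY not found in environment variables")
        raise ValueError("HUGGINGFACE_API_KEY not found in environment variables")
    # Authenticate only after the key is known to exist
    login(token=huggingface_api_key)
    return huggingface_api_key  # assumed; not shown in the hunk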
@@ -1131,10 +1134,19 @@ def initialize_safety_client():
         raise
 
 
+# Initialize safety model pipeline
+try:
+    safety_client = initialize_safety_client()
+
+except Exception as e:
+    logger.error(f"Failed to initialize model: {str(e)}")
+    # Fallback to CPU if GPU initialization fails
+
+
 def is_safe(message: str) -> bool:
     """Check content safety using Inference API"""
     try:
-        client = initialize_safety_client()
+        # client = initialize_safety_client()
 
         messages = [
             {"role": "user", "content": f"Check if this content is safe:\n{message}"},
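Review note: this moves client construction out of is_safe() (removed below) and into a module-level safety_client built once at import time. Two things to check: the except branch calls logger.error while the code above uses logging.error, which raises a NameError if no module-level logger is defined elsewhere in the file; and if initialize_safety_client() fails, safety_client is never bound, so the first call to is_safe() would also fail with a NameError. The "Fallback to CPU" comment has no accompanying fallback code. initialize_safety_client() itself is not shown in this diff; a hypothetical sketch of what it might wrap, assuming it returns a huggingface_hub.InferenceClient:

from huggingface_hub import InferenceClient

def initialize_safety_client() -> InferenceClient:
    # Hypothetical sketch; the real implementation is outside this diff.
    api_key = get_huggingface_api_key()
    return InferenceClient(token=api_key)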
@@ -1146,7 +1158,7 @@ def is_safe(message: str) -> bool:
         ]
 
         try:
-            completion = client.chat.completions.create(
+            completion = safety_client.chat.completions.create(
                 model=MODEL_CONFIG["safety_model"]["name"],
                 messages=messages,
                 max_tokens=MODEL_CONFIG["safety_model"]["max_tokens"],
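Review note: only the client name changes here; huggingface_hub's InferenceClient exposes an OpenAI-style chat.completions.create() interface, so the call site is otherwise untouched. How the completion becomes a boolean falls outside this hunk; a hypothetical sketch of the continuation inside is_safe(), assuming a Llama-Guard-style safety model that answers with a leading "safe" or "unsafe" verdict:

completion = safety_client.chat.completions.create(
    model=MODEL_CONFIG["safety_model"]["name"],
    messages=messages,
    max_tokens=MODEL_CONFIG["safety_model"]["max_tokens"],
)
# Hypothetical parsing; the actual handling is outside this diff.
verdict = completion.choices[0].message.content.strip().lower()
return verdict.startswith("safe")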