Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,14 @@ class TextDetectionApp:
|
|
12 |
self.roberta_tokenizer = AutoTokenizer.from_pretrained("zeyadusf/roberta-DAIGT-kaggle")
|
13 |
self.roberta_model = AutoModelForSequenceClassification.from_pretrained("zeyadusf/roberta-DAIGT-kaggle")
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Load Feedforward model
|
16 |
self.ff_model = torch.jit.load("model_scripted.pt")
|
17 |
|
@@ -91,7 +99,7 @@ class TextDetectionApp:
|
|
91 |
with torch.no_grad():
|
92 |
detection_score = self.ff_model(self.generate_ff_input(self.api_huggingface(text)))[0][0].item()
|
93 |
# Return result based on the score threshold
|
94 |
-
return "Generated" if detection_score > 0.5 else "Human-Written"
|
95 |
|
96 |
def classify_text(self, text, model_choice):
|
97 |
"""
|
@@ -99,7 +107,7 @@ class TextDetectionApp:
|
|
99 |
|
100 |
Args:
|
101 |
text (str): The input text to classify.
|
102 |
-
model_choice (str): The model to use ('DeBERTa', 'RoBERTa', or 'Feedforward').
|
103 |
|
104 |
Returns:
|
105 |
str: The classification result.
|
@@ -114,9 +122,8 @@ class TextDetectionApp:
|
|
114 |
# Get classification results
|
115 |
logits = outputs.logits
|
116 |
predicted_class_id = logits.argmax().item()
|
117 |
-
label = "Generated" if predicted_class_id == 1 else "Human-Written"
|
118 |
-
return f"
|
119 |
-
|
120 |
elif model_choice == 'RoBERTa':
|
121 |
# Tokenize input
|
122 |
inputs = self.roberta_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
@@ -127,13 +134,37 @@ class TextDetectionApp:
|
|
127 |
# Get classification results
|
128 |
logits = outputs.logits
|
129 |
predicted_class_id = logits.argmax().item()
|
130 |
-
label = "Generated" if predicted_class_id == 1 else "Human-Written"
|
131 |
-
return f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
elif model_choice == 'Feedforward':
|
134 |
# Run feedforward detection
|
135 |
detection_result = self.detect_text(text)
|
136 |
-
return f"
|
137 |
|
138 |
else:
|
139 |
return "Invalid model selection."
|
@@ -147,11 +178,11 @@ iface = gr.Interface(
|
|
147 |
fn=app.classify_text,
|
148 |
inputs=[
|
149 |
gr.Textbox(lines=2, placeholder="Enter your text here..."),
|
150 |
-
gr.Radio(choices=["DeBERTa", "RoBERTa", "Feedforward"], label="Model Choice")
|
151 |
],
|
152 |
outputs="text",
|
153 |
title="Text Classification with Multiple Models",
|
154 |
-
description="Classify text as generated or human-written using DeBERTa, RoBERTa, or a custom Feedforward model."
|
155 |
)
|
156 |
|
157 |
iface.launch()
|
|
|
12 |
self.roberta_tokenizer = AutoTokenizer.from_pretrained("zeyadusf/roberta-DAIGT-kaggle")
|
13 |
self.roberta_model = AutoModelForSequenceClassification.from_pretrained("zeyadusf/roberta-DAIGT-kaggle")
|
14 |
|
15 |
+
# Load BERT model and tokenizer
|
16 |
+
self.bert_tokenizer = AutoTokenizer.from_pretrained("zeyadusf/bert-DAIGT-MODELS")
|
17 |
+
self.bert_model = AutoModelForSequenceClassification.from_pretrained("zeyadusf/bert-DAIGT-MODELS")
|
18 |
+
|
19 |
+
# Load DistilBERT model and tokenizer
|
20 |
+
self.distilbert_tokenizer = AutoTokenizer.from_pretrained("zeyadusf/distilbert-DAIGT-MODELS")
|
21 |
+
self.distilbert_model = AutoModelForSequenceClassification.from_pretrained("zeyadusf/distilbert-DAIGT-MODELS")
|
22 |
+
|
23 |
# Load Feedforward model
|
24 |
self.ff_model = torch.jit.load("model_scripted.pt")
|
25 |
|
|
|
99 |
with torch.no_grad():
|
100 |
detection_score = self.ff_model(self.generate_ff_input(self.api_huggingface(text)))[0][0].item()
|
101 |
# Return result based on the score threshold
|
102 |
+
return "Generated Text" if detection_score > 0.5 else "Human-Written"
|
103 |
|
104 |
def classify_text(self, text, model_choice):
|
105 |
"""
|
|
|
107 |
|
108 |
Args:
|
109 |
text (str): The input text to classify.
|
110 |
+
model_choice (str): The model to use ('DeBERTa', 'RoBERTa', 'BERT', 'DistilBERT', or 'Feedforward').
|
111 |
|
112 |
Returns:
|
113 |
str: The classification result.
|
|
|
122 |
# Get classification results
|
123 |
logits = outputs.logits
|
124 |
predicted_class_id = logits.argmax().item()
|
125 |
+
label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
|
126 |
+
return f"{label} )"
|
|
|
127 |
elif model_choice == 'RoBERTa':
|
128 |
# Tokenize input
|
129 |
inputs = self.roberta_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
|
|
134 |
# Get classification results
|
135 |
logits = outputs.logits
|
136 |
predicted_class_id = logits.argmax().item()
|
137 |
+
label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
|
138 |
+
return f"{label} )"
|
139 |
+
elif model_choice == 'BERT':
|
140 |
+
# Tokenize input
|
141 |
+
inputs = self.bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
142 |
+
|
143 |
+
# Run model
|
144 |
+
outputs = self.bert_model(**inputs)
|
145 |
+
|
146 |
+
# Get classification results
|
147 |
+
logits = outputs.logits
|
148 |
+
predicted_class_id = logits.argmax().item()
|
149 |
+
label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
|
150 |
+
return f"{label} )"
|
151 |
+
elif model_choice == 'DistilBERT':
|
152 |
+
# Tokenize input
|
153 |
+
inputs = self.distilbert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
154 |
+
|
155 |
+
# Run model
|
156 |
+
outputs = self.distilbert_model(**inputs)
|
157 |
+
|
158 |
+
# Get classification results
|
159 |
+
logits = outputs.logits
|
160 |
+
predicted_class_id = logits.argmax().item()
|
161 |
+
label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
|
162 |
+
return f"{label} )"
|
163 |
|
164 |
elif model_choice == 'Feedforward':
|
165 |
# Run feedforward detection
|
166 |
detection_result = self.detect_text(text)
|
167 |
+
return f"{detection_result}"
|
168 |
|
169 |
else:
|
170 |
return "Invalid model selection."
|
|
|
178 |
fn=app.classify_text,
|
179 |
inputs=[
|
180 |
gr.Textbox(lines=2, placeholder="Enter your text here..."),
|
181 |
+
gr.Radio(choices=["DeBERTa", "RoBERTa", "BERT", "DistilBERT", "Feedforward"], label="Model Choice")
|
182 |
],
|
183 |
outputs="text",
|
184 |
title="Text Classification with Multiple Models",
|
185 |
+
description="Classify text as generated or human-written using DeBERTa, RoBERTa, BERT, DistilBERT, or a custom Feedforward model."
|
186 |
)
|
187 |
|
188 |
iface.launch()
|