spawn99 committed on
Commit
ba33d46
·
verified ·
1 Parent(s): f5bde00

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import gradio as gr
4
+
5
+
6
# Hugging Face Hub identifiers used by run_inference.
_MODEL_ID = "spawn99/modernbert-wine-classification"
_TOKENIZER_ID = "answerdotai/ModernBERT-base"

# Holds the loaded (tokenizer, model) pair under the key "pair".  Filled on
# first use so the weights are loaded once per process instead of once per
# prediction request.
_INFERENCE_CACHE = {}


def _get_tokenizer_and_model():
    """Return the (tokenizer, model) pair, loading both on the first call only."""
    pair = _INFERENCE_CACHE.get("pair")
    if pair is None:
        tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_ID)
        model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)
        model.eval()
        # Store both objects in a single assignment so a concurrent caller can
        # never observe a half-populated cache.
        pair = (tokenizer, model)
        _INFERENCE_CACHE["pair"] = pair
    return pair


def run_inference(review_text: str) -> str:
    """
    Perform inference on the given wine review text and return the predicted wine variety.

    The tokenizer and model are loaded lazily on the first call and reused on
    every later call.

    Args:
        review_text (str): Wine review text in the format "country [SEP] description".

    Returns:
        str: The predicted wine variety taken from the model's id2label mapping,
            or the predicted class index as a string when no label is available.
    """
    tokenizer, model = _get_tokenizer_and_model()

    # Tokenize the input text (padded / truncated to 256 tokens).
    inputs = tokenizer(
        review_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=256,
    )

    # Gradients are not needed for prediction.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Index of the highest-scoring class.
    pred = torch.argmax(logits, dim=-1).item()

    # Map the index to a label when the config provides one.
    id2label = getattr(model.config, "id2label", None)
    if id2label:
        return id2label.get(pred, str(pred))
    return str(pred)
47
+
48
+
49
# Longest description (in characters, after trimming) accepted for inference.
_MAX_DESCRIPTION_CHARS = 750


def predict_wine_variety(country: str, description: str) -> dict:
    """
    Combine the provided country and description, then perform inference.

    Both inputs are trimmed of surrounding whitespace first; a missing value
    (None) is treated as an empty string.  The description must be non-empty
    and at most 750 characters long.

    Args:
        country (str): The country of wine origin.
        description (str): The wine review description.

    Returns:
        dict: {"Variety": <prediction>} on success, or {"error": <message>}
            when the description is empty or exceeds the character limit.
    """
    # Normalise first: None would otherwise crash len()/capitalize(), and
    # leading whitespace would stop capitalize() from reaching the first letter.
    country = (country or "").strip()
    description = (description or "").strip()

    # Validate the description before touching the model.
    if not description:
        return {"error": "Description is empty. Please enter a wine review description."}
    if len(description) > _MAX_DESCRIPTION_CHARS:
        return {
            "error": f"Description exceeds {_MAX_DESCRIPTION_CHARS} character limit. "
                     "Please shorten your input."
        }

    # Capitalize input values and format the review text accordingly.
    # NOTE(review): str.capitalize() also lower-cases every character after the
    # first (e.g. "US" -> "Us"); kept as-is because the casing the model expects
    # is not visible here -- confirm against the training preprocessing.
    review_text = f"{country.capitalize()} [SEP] {description.capitalize()}"
    return {"Variety": run_inference(review_text)}
69
+
70
+
71
if __name__ == "__main__":
    # Two free-text inputs feed predict_wine_variety; its dict result is
    # rendered in a JSON output component.
    iface = gr.Interface(
        fn=predict_wine_variety,
        inputs=[
            gr.Textbox(label="Country", placeholder="Enter country of origin..."),
            gr.Textbox(label="Description", placeholder="Enter wine review description..."),
        ],
        outputs=gr.JSON(label="Prediction"),
        title="Wine Variety Predictor",
        description="Predict the wine variety based on country and description.",
        # `flagging` is not an Interface parameter, so it never disabled the
        # flag button.  `flagging_mode` is the Gradio 5+ spelling; releases
        # before 5.0 call it `allow_flagging` -- adjust if an older version is pinned.
        flagging_mode="never",
    )
    iface.launch()