spawn99 committed on
Commit 9ea2a9b · verified · 1 Parent(s): c043a92

Update app.py

Files changed (1)
app.py +32 -10
app.py CHANGED
@@ -1,11 +1,13 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import gradio as gr
+import json  # Added for JSON conversion


 def run_inference(review_text: str) -> str:
     """
-    Perform inference on the given wine review text and return the predicted wine variety.
+    Perform inference on the given wine review text and return the predicted wine variety
+    using ModernBERT, an encoder-only classifier from "spawn99/modernbert-wine-classification".

     Args:
         review_text (str): Wine review text in the format "country [SEP] description".
@@ -19,6 +21,7 @@ def run_inference(review_text: str) -> str:

     # Load tokenizer and model
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+    # The model used here is a ModernBERT encoder-only classifier.
     model = AutoModelForSequenceClassification.from_pretrained(model_id)

     # Tokenize the input text
@@ -46,26 +49,37 @@ def run_inference(review_text: str) -> str:
     return variety


-def predict_wine_variety(country: str, description: str) -> dict:
+def predict_wine_variety(country: str, description: str, output_format: str) -> str:
     """
-    Combine the provided country and description, then perform inference.
+    Combine the provided country and description, perform inference, and format the output
+    based on the selected output format.
+
     Enforces a maximum character limit of 750 on the description.

     Args:
         country (str): The country of wine origin.
         description (str): The wine review description.
+        output_format (str): Either "JSON" to return output as a JSON-formatted string,
+            or "Text" for plain text output.

     Returns:
-        dict: Dictionary containing the predicted wine variety or an error message if the limit is exceeded.
+        str: The predicted wine variety formatted as JSON (if selected) or as plain text.
     """
-    # Validate description length
     if len(description) > 750:
-        return {"error": "Description exceeds 750 character limit. Please shorten your input."}
+        error_msg = "Description exceeds 750 character limit. Please shorten your input."
+        if output_format.lower() == "json":
+            return json.dumps({"error": error_msg}, indent=2)
+        else:
+            return error_msg

     # Capitalize input values and format the review text accordingly.
     review_text = f"{country.capitalize()} [SEP] {description.capitalize()}"
     predicted_variety = run_inference(review_text)
-    return {"Variety": predicted_variety}
+
+    if output_format.lower() == "json":
+        return json.dumps({"Variety": predicted_variety}, indent=2)
+    else:
+        return predicted_variety


 if __name__ == "__main__":
@@ -73,10 +87,18 @@ if __name__ == "__main__":
         fn=predict_wine_variety,
         inputs=[
             gr.Textbox(label="Country", placeholder="Enter country of origin..."),
-            gr.Textbox(label="Description", placeholder="Enter wine review description...")
+            gr.Textbox(label="Description", placeholder="Enter wine review description..."),
+            # New radio input to choose between JSON and plain text output formats:
+            gr.Radio(choices=["JSON", "Text"], value="JSON", label="Output Format")
         ],
-        outputs=gr.JSON(label="Prediction"),
+        # Changed outputs to a Textbox so that plain text output shows naturally
+        outputs=gr.Textbox(label="Prediction"),
         title="Wine Variety Predictor",
-        description="Predict the wine variety based on country and description."
+        description=(
+            "Predict the wine variety based on the country and wine review.\n\n"
+            "This tool uses ModernBERT, an encoder-only classifier, trained on the wine reviews dataset\n"
+            "(model: spawn99/modernbert-wine-classification, dataset: spawn99/wine-reviews).\n\n"
+            "Use the Output Format selector to toggle between a JSON-formatted result and a plain text prediction."
+        )
     )
     iface.launch()
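
For a quick sanity check of the new output_format behaviour, the snippet below is a minimal sketch and not part of the commit: it assumes app.py is importable locally, that the model and tokenizer can be downloaded from the Hub, and the sample country/description values are placeholders.

# Minimal local smoke test for predict_wine_variety (illustrative only; the
# sample inputs below are placeholders, not taken from the commit).
import json

from app import predict_wine_variety

country = "Italy"
description = "Aromas of ripe cherry, dried herbs and a hint of leather on the nose."

# "JSON" returns a JSON-formatted string, e.g. '{\n  "Variety": "..."\n}'
json_result = predict_wine_variety(country, description, "JSON")
print(json.loads(json_result).get("Variety"))

# "Text" returns the predicted variety as a plain string
print(predict_wine_variety(country, description, "Text"))

# Descriptions longer than 750 characters are rejected with an error message
print(predict_wine_variety("France", "x" * 751, "Text"))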