cheberle committed on
Commit: c60d44f · 1 Parent(s): 84d29be
Files changed (1):
  1. app.py +34 -63
app.py CHANGED
@@ -1,84 +1,55 @@
  import gradio as gr
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM


  # ---------------------------------------------------------------------------
- # 1) Load the model and tokenizer
  # ---------------------------------------------------------------------------
- # If you want to load in 8-bit or 4-bit precision with bitsandbytes,
- # uncomment and install bitsandbytes, and set load_in_8bit=True or load_in_4bit=True.
- # For example:
- #
- # from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
- # bnb_config = BitsAndBytesConfig(
- #     load_in_4bit=True,                     # or load_in_8bit=True
- #     bnb_4bit_compute_dtype=torch.float16,  # recommended compute dtype
- #     bnb_4bit_use_double_quant=True,        # optional
- #     bnb_4bit_quant_type='nf4',             # optional
- # )
- #
- # model = AutoModelForCausalLM.from_pretrained(
- #     "cheberle/autotrain-35swc-b4r9z",
- #     quantization_config=bnb_config,
- #     device_map="auto",
- #     trust_remote_code=True
- # )
- # tokenizer = AutoTokenizer.from_pretrained("cheberle/autotrain-35swc-b4r9z", trust_remote_code=True)

- # For a standard FP16 or FP32 load (no bitsandbytes):
- model = AutoModelForCausalLM.from_pretrained(
-     "cheberle/autotrain-35swc-b4r9z",
-     torch_dtype=torch.float16,  # Or "auto", or float32
-     trust_remote_code=True
- )
- tokenizer = AutoTokenizer.from_pretrained(
-     "cheberle/autotrain-35swc-b4r9z",
-     trust_remote_code=True
- )

  # ---------------------------------------------------------------------------
- # 2) Define a text generation function
  # ---------------------------------------------------------------------------
- def generate_text(prompt):
-     # Tokenize input
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-     # Generate output (configure generation args as needed)
-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=128,
-             temperature=0.7,
-             top_p=0.9,
-             do_sample=True
-         )
-
-     # Decode
-     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return decoded

  # ---------------------------------------------------------------------------
- # 3) Create the Gradio interface
  # ---------------------------------------------------------------------------
  with gr.Blocks() as demo:
-     gr.Markdown("<h3>Demo: cheberle/autotrain-35swc-b4r9z</h3>")

      with gr.Row():
-         with gr.Column():
-             prompt_in = gr.Textbox(
-                 lines=5,
-                 label="Enter your prompt",
-                 placeholder="Ask something here..."
-             )
-             submit_btn = gr.Button("Generate")
-         with gr.Column():
-             output_box = gr.Textbox(lines=15, label="Model Output")

-     # Define what happens on button click
-     submit_btn.click(fn=generate_text, inputs=prompt_in, outputs=output_box)

  # ---------------------------------------------------------------------------
- # 4) Launch!
  # ---------------------------------------------------------------------------
  if __name__ == "__main__":
      demo.launch()

  import gradio as gr
  import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
+
+ MODEL_NAME = "cheberle/autotrain-35swc-b4r9z"

  # ---------------------------------------------------------------------------
+ # 1) Load the tokenizer and model for sequence classification
  # ---------------------------------------------------------------------------
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, trust_remote_code=True)

+ # Create a pipeline for text classification
+ classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

  # ---------------------------------------------------------------------------
+ # 2) Define inference function
  # ---------------------------------------------------------------------------
+ def classify_text(text):
+     """
+     Return the classification results in the format:
+     [
+         {
+             'label': 'POSITIVE',
+             'score': 0.98
+         }
+     ]
+     """
+     results = classifier(text)
+     return results

  # ---------------------------------------------------------------------------
+ # 3) Build the Gradio UI
  # ---------------------------------------------------------------------------
  with gr.Blocks() as demo:
+     gr.Markdown("<h3>Text Classification Demo</h3>")

      with gr.Row():
+         input_text = gr.Textbox(
+             lines=3,
+             label="Enter text to classify",
+             placeholder="Type something..."
+         )
+         output = gr.JSON(label="Classification Output")
+
+     classify_btn = gr.Button("Classify")

+     # Link the button to the function
+     classify_btn.click(fn=classify_text, inputs=input_text, outputs=output)

  # ---------------------------------------------------------------------------
+ # 4) Launch the demo
  # ---------------------------------------------------------------------------
  if __name__ == "__main__":
      demo.launch()
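
The new classification path can be sanity-checked without launching the Gradio UI. The snippet below is a minimal sketch, not part of this commit: it rebuilds the same pipeline from the repo name and passes top_k=None (return_all_scores=True on older transformers releases) so that every label gets a score, which the gr.JSON output component could display as-is. The example text and printed output are illustrative, not real results.

from transformers import pipeline

MODEL_NAME = "cheberle/autotrain-35swc-b4r9z"

# The pipeline resolves the tokenizer and model from the repo name itself.
classifier = pipeline("text-classification", model=MODEL_NAME, trust_remote_code=True)

# top_k=None returns a score for every label instead of only the top one.
results = classifier("Great build quality, fast shipping.", top_k=None)
print(results)  # e.g. [{'label': '...', 'score': 0.97}, {'label': '...', 'score': 0.03}]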
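
Once demo.launch() is running, the click handler can also be driven programmatically. A sketch assuming the app is serving locally on Gradio's default port (7860); fn_index=0 addresses the single event handler registered in the Blocks above.

from gradio_client import Client

# Assumes the app from this commit is running locally on the default port.
client = Client("http://127.0.0.1:7860")

# fn_index=0 targets the only .click() handler wired up in the demo.
result = client.predict("Type something worth classifying...", fn_index=0)
print(result)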