File size: 3,516 Bytes
aea9d52
c4dd8e7
aea9d52
 
 
 
2f6ade8
 
 
aea9d52
 
 
 
2f6ade8
 
 
 
 
 
 
49ce6a9
 
2f6ade8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49ce6a9
 
 
 
 
 
 
 
e328eaa
 
dd12fa7
74f91bd
dd12fa7
 
74f91bd
dd12fa7
 
74f91bd
dd12fa7
 
 
74f91bd
dd12fa7
 
 
 
 
e328eaa
dd12fa7
 
74f91bd
dd12fa7
 
74f91bd
dd12fa7
 
 
 
 
 
74f91bd
aea9d52
 
 
 
 
 
 
7768e0f
aea9d52
7768e0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr

# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

import torch
import transformers

# Checkpoint id shared by the tokenizer and the sequence-classification model,
# so the two can never drift apart.
MODEL_ID = "nebiyu29/fintunned-v2-roberta_GA"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)


# Alternative checkpoint (currently disabled; the active model and tokenizer are loaded above):
# model = transformers.AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")

# tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

# Define a function to split a text into segments of 512 tokens
def split_text(text, max_tokens=512):
    """Split *text* into consecutive segments of at most *max_tokens* tokens.

    The text is tokenized with the module-level ``tokenizer``, the token list
    is cut into fixed-size windows, and every window is turned back into a
    string with ``tokenizer.convert_tokens_to_string``.

    Args:
        text: The raw input string to split.
        max_tokens: Maximum number of tokens per segment (default 512). The
            count covers only what ``tokenizer.tokenize`` returns.
            NOTE(review): special tokens added when a segment is encoded
            later presumably come on top of this budget -- confirm against
            the model's maximum input length.

    Returns:
        A list of segment strings in their original order. The list is empty
        when tokenization yields no tokens.

    Raises:
        ValueError: If *max_tokens* is smaller than 1.
    """
    if max_tokens < 1:
        raise ValueError("max_tokens must be a positive integer")

    # this prints progress
    print("going to split the text")

    tokens = tokenizer.tokenize(text)

    # Cut by position. Detecting the end with ``token == tokens[-1]`` compares
    # token *values*, so any earlier token equal to the final one (a repeated
    # "." for instance) would close a segment prematurely.
    return [
        tokenizer.convert_tokens_to_string(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]

# Define a function to extract predictions from model output (adjust as needed)
def extract_predictions(outputs):
    """Convert raw model outputs into probabilities and predicted labels.

    Args:
        outputs: Any object exposing a ``logits`` tensor. Softmax and argmax
            are both taken over dimension 1 (assumes logits are laid out as
            ``(batch, num_classes)`` -- TODO confirm for the model in use).

    Returns:
        A ``(probs, preds)`` tuple: the softmax of the logits along
        dimension 1, and the index of the largest probability along that
        same dimension.
    """
    probabilities = torch.softmax(outputs.logits, dim=1)
    predicted = probabilities.argmax(dim=1)
    return probabilities, predicted
    
# a function that classifies text

def classify_text(text):
    """Classify *text* segment by segment and return a printable summary.

    The input is cut into segments with ``split_text``; each segment is
    encoded with the module-level ``tokenizer``, run through the module-level
    ``model`` without gradient tracking, and reduced to its most likely class
    with ``extract_predictions``.

    Args:
        text: The raw string to classify.

    Returns:
        A string with one line per segment of the form
        ``"Segment <n>: <label> (<probability>)"``, suitable for a Gradio
        text output. When ``split_text`` yields no segments a short notice
        is returned instead.
    """
    segments = split_text(text)
    if not segments:
        return "No text to classify."

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # ``Module.to`` moves the parameters in place. The name must NOT be
    # rebound here: an assignment would make ``model`` local to this function
    # and raise UnboundLocalError before the global is ever read.
    model.to(device)

    # Map class indices to readable names when the config provides them.
    id2label = getattr(model.config, "id2label", None) or {}

    lines = []
    for index, segment in enumerate(segments, start=1):
        # truncation=True keeps the encoded segment (content plus any special
        # tokens) within the tokenizer's configured maximum length.
        inputs = tokenizer(
            [segment],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        probs, preds = extract_predictions(outputs)

        # The batch holds exactly one segment, so row 0 is the one of interest.
        # Index the class dimension explicitly: ``probs[preds[0]]`` would pick
        # a batch row and fail for any predicted class other than 0.
        label_id = preds[0].item()
        probability = probs[0, label_id].item()
        label = id2label.get(label_id, str(label_id))

        lines.append(f"Segment {index}: {label} ({probability:.2%})")

    return "\n".join(lines)


# Gradio front end: a single text input routed through classify_text, with
# the result shown as text.
interface = gr.Interface(
    title="Text Classification Demo",
    description="Enter some text, and the model will classify it.",
    fn=classify_text,
    inputs="text",
    outputs="text",
)

#interface.launch(server_port=8080)