---
license: mit
language:
- en
tags:
- medical
- radiology
model-index:
- name: rate-ner-rad
  results: []
pipeline_tag: token-classification
widget:
- text: No suspicious focal mass lesion is seen in the left kidney.
  example_title: Example from Radiopaedia
---

# RaTE-NER-Deberta

This model is a fine-tuned version of [DeBERTa](https://huggingface.co/microsoft/deberta-v3-base) on the [RaTE-NER](https://huggingface.co/datasets/Angelakeke/RaTE-NER/) dataset.

## Model description

This model was trained to serve the RaTEScore metric. If you are interested in our full pipeline, please refer to our [paper](https://aclanthology.org/2024.emnlp-main.836.pdf) and [GitHub repository](https://github.com/Angelakeke/RaTEScore).

The model can also be used on its own to extract **Abnormality, Non-Abnormality, Anatomy, Disease, and Non-Disease** entities from medical radiology reports.
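
For illustration, the extraction code in the next section returns a list of `(entity_text, entity_type)` pairs. A plausible (purely hypothetical, not verified) output for the widget sentence above would look like:

```python
# Hypothetical illustration of the output format only; actual predictions
# come from running the model as shown in the Usage section below.
[("suspicious focal mass lesion", "NON-ABNORMALITY"), ("left kidney", "ANATOMY")]
```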

## Usage

<details>
  <summary> Click to expand the usage example. </summary>
<pre><code>
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
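
# Merge token-level BIO predictions into (entity text, entity type) pairs.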
def post_process(tokenized_text, predicted_entities, tokenizer):
    entity_spans = []
    start = end = None
    entity_type = None
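    # Walk tokens and labels in parallel: open a span at B- and close it when
    # the next label no longer continues the current entity.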
    for i, (token, label) in enumerate(zip(tokenized_text, predicted_entities[:len(tokenized_text)])):
        if token in ["[CLS]", "[SEP]"]:
            continue
        if label != "O" and i < len(predicted_entities) - 1:
            if label.startswith("B-") and predicted_entities[i+1].startswith("I-"):
                start = i
                entity_type = label[2:]
            elif label.startswith("B-") and predicted_entities[i+1].startswith("B-"):
                start = i
                end = i
                entity_spans.append((start, end, label[2:]))
                start = i
                entity_type = label[2:]
            elif label.startswith("B-") and predicted_entities[i+1].startswith("O"):
                start = i
                end = i
                entity_spans.append((start, end, label[2:]))
                start = end = None
                entity_type = None
            elif label.startswith("I-") and predicted_entities[i+1].startswith("B-"):
                end = i
                if start is not None:
                    entity_spans.append((start, end, entity_type))
                start = i
                entity_type = label[2:]
            elif label.startswith("I-") and predicted_entities[i+1].startswith("O"):
                end = i
                if start is not None:
                    entity_spans.append((start, end, entity_type))
                start = end = None
                entity_type = None
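    # Flush a span that is still open at the end of the sentence ([SEP] excluded).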
    if start is not None and end is None:
        end = len(tokenized_text) - 2
        entity_spans.append((start, end, entity_type))
    save_pair = []
    for start, end, entity_type in entity_spans:
        entity_str = tokenizer.convert_tokens_to_string(tokenized_text[start:end+1])
        save_pair.append((entity_str, entity_type))
    return save_pair

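# Tokenize a batch of sentences, run the model, and convert the predicted
# labels into (entity, type) pairs via post_process.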
def run_ner(texts, idx2label, tokenizer, model, device):
    inputs = tokenizer(texts,
                       max_length=512,
                       padding=True,
                       truncation=True,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_labels = torch.argmax(outputs.logits, dim=2).tolist()
    save_pairs = []
    for i in range(len(texts)):
        predicted_entities = [idx2label[label] for label in predicted_labels[i]]
        non_pad_mask = inputs["input_ids"][i] != tokenizer.pad_token_id
        non_pad_length = non_pad_mask.sum().item()
        non_pad_input_ids = inputs["input_ids"][i][:non_pad_length]
        tokenized_text = tokenizer.convert_ids_to_tokens(non_pad_input_ids)
        save_pair = post_process(tokenized_text, predicted_entities, tokenizer)
        save_pairs.extend(save_pair)
    return save_pairs

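# BIO label set from fine-tuning; the index order matches the model's
# classification head.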
ner_labels = ['B-ABNORMALITY', 'I-ABNORMALITY', 
              'B-NON-ABNORMALITY', 'I-NON-ABNORMALITY', 
              'B-DISEASE', 'I-DISEASE', 
              'B-NON-DISEASE', 'I-NON-DISEASE', 
              'B-ANATOMY', 'I-ANATOMY', 
              'O']
idx2label = {i: label for i, label in enumerate(ner_labels)}

tokenizer = AutoTokenizer.from_pretrained('Angelakeke/RaTE-NER-Deberta')
model = AutoModelForTokenClassification.from_pretrained('Angelakeke/RaTE-NER-Deberta')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# We recommend running inference sentence by sentence:

text = ""

texts = text.split('. ')
save_pair = run_ner(texts, idx2label, tokenizer, model, device)

 </code></pre>

</details>
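
Alternatively, for quick experiments you can let the built-in `transformers` token-classification pipeline handle span aggregation. This is a minimal sketch; its `aggregation_strategy="simple"` merging may differ slightly from the custom `post_process` above.

```python
from transformers import pipeline

# Minimal sketch using the generic token-classification pipeline.
# Span merging ("simple" aggregation) may differ from post_process above.
ner = pipeline(
    "token-classification",
    model="Angelakeke/RaTE-NER-Deberta",
    aggregation_strategy="simple",
)

print(ner("No suspicious focal mass lesion is seen in the left kidney."))
```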


## Author

Author: [Weike Zhao](https://angelakeke.github.io/)

If you have any questions, please feel free to contact [email protected].

## Citation
```bibtex
@inproceedings{zhao2024ratescore,
  title={RaTEScore: A Metric for Radiology Report Generation},
  author={Zhao, Weike and Wu, Chaoyi and Zhang, Xiaoman and Zhang, Ya and Wang, Yanfeng and Xie, Weidi},
  booktitle={Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing},
  pages={15004--15019},
  year={2024}
}
```