# SciBERT Fine-Tuning 

---

In this notebook, we use the 🤗 `transformers` library to fine-tune the `allenai/scibert_scivocab_uncased` model on a consolidated dataset of drug, disease, and gene annotations. The goal is for the fine-tuned model to perform named entity recognition (NER) of drugs, diseases, and genes.

!pip install --target=modules datasets transformers seqeval spacy evaluate bioc gradio

In [1]:
from datasets import Dataset, ClassLabel, Sequence, load_dataset, load_metric
import numpy as np
import pandas as pd
import bioc
from spacy import displacy
import transformers
#import evaluate
from transformers import (AutoModelForTokenClassification, 
                          AutoTokenizer, 
                          DataCollatorForTokenClassification,
                          pipeline,
                          TrainingArguments, 
                          Trainer)



In [2]:
# confirm version > 4.11.0
print(transformers.__version__)

4.31.0


## Reading files

In [3]:
# no train-test provided, so we create our own
cons_dataset = load_dataset("json", data_files="./allDataShort2.json")
#cons_dataset = datasets
cons_dataset = cons_dataset["train"].train_test_split(test_size=0.001)

In [4]:
cons_dataset

DatasetDict({
    train: Dataset({
        features: ['index', 'text', 'gene', 'gene_indices_start', 'gene_indices_end', 'disease', 'disease_indices_start', 'disease_indices_end', 'drugs', 'drug_indices_start', 'drug_indices_end'],
        num_rows: 6554
    })
    test: Dataset({
        features: ['index', 'text', 'gene', 'gene_indices_start', 'gene_indices_end', 'disease', 'disease_indices_start', 'disease_indices_end', 'drugs', 'drug_indices_start', 'drug_indices_end'],
        num_rows: 7
    })
})
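The split sizes can be checked by hand. Assuming `train_test_split` follows scikit-learn's convention of rounding the test size up and the train size down (an assumption about the library internals), 6,561 rows with `test_size=0.001` give 7 test rows and 6,554 training rows, which matches the output above:

```python
import math

def split_sizes(n_samples: int, test_size: float) -> tuple[int, int]:
    """Return (n_train, n_test), rounding the test split up and the train split down."""
    n_test = math.ceil(test_size * n_samples)
    n_train = math.floor((1 - test_size) * n_samples)
    return n_train, n_test

# Total rows in allDataShort2.json, recovered from the split printed above
total = 6554 + 7
print(split_sizes(total, 0.001))  # (6554, 7)
```

With a test split this small, the evaluation metrics later in the notebook rest on only a handful of sentences.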

---
## Token Labeling

Next, we label each token with its entity. We use BIO tagging on three entity types, `DRUG`, `DISEASE`, and `GENE`. This results in seven possible classes for each token:

* `O` - outside any entity we care about
* `B-DRUG` - the beginning of a `DRUG` entity
* `I-DRUG` - inside a `DRUG` entity
* `B-DISEASE` - the beginning of a `DISEASE` entity
* `I-DISEASE` - inside a `DISEASE` entity
* `B-GENE` - the beginning of a `GENE` entity
* `I-GENE` - inside a `GENE` entity
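Each tag's position in the tag list is its integer class id. The sketch below builds the two lookup tables. The names `id2label` and `label2id` match the keyword arguments that 🤗 `from_pretrained` forwards to the model config. Passing them there would make the pipeline print tag names instead of `LABEL_<id>`, but this notebook does not do that. The helper `readable` is hypothetical and exists only for illustration:

```python
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-DISEASE', 'I-DISEASE', 'B-GENE', 'I-GENE']

# Position in label_list is the integer class id stored in the "labels" column
id2label = {i: tag for i, tag in enumerate(label_list)}
label2id = {tag: i for i, tag in enumerate(label_list)}

def readable(entity: str) -> str:
    """Hypothetical helper: translate a generic pipeline label such as 'LABEL_3' to its BIO tag."""
    return id2label[int(entity.split("_")[-1])]

print(label2id["B-GENE"])   # 5
print(readable("LABEL_3"))  # B-DISEASE
```

Suppose these tables were passed when loading the model, as in `AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list), id2label=id2label, label2id=label2id)`. The `LABEL_<id>` parsing in `visualize_entities` would then need to change to match the new tag names.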

In [5]:
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-DISEASE', 'I-DISEASE', 'B-GENE', 'I-GENE']

custom_seq = Sequence(feature=ClassLabel(num_classes=7, 
                                         names=label_list,
                                         names_file=None, id=None), length=-1, id=None)

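# The dataset has no `ner_tags` column (see the features listed above); the integer
# tags are written to a `labels` column by `generate_row_labels` below, with each
# id indexing `label_list`.
cons_dataset["train"].features["ner_tags"] = custom_seq
cons_dataset["test"].features["ner_tags"] = custom_seq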

In [6]:
#model_checkpoint = "allenai/scibert_scivocab_uncased"
model_checkpoint = './trainedSB2'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [7]:
def generate_row_labels(row, verbose=False):
    """ Given a row from the consolidated dataset,
    generates BIO tags for drug, disease, and gene entities.
    Returns the tokenizer output with an added "labels" list
    (-100 for special tokens, otherwise an index into label_list).
    """

    text = row["text"]

    labels = []
    label = "O"
    prefix = ""

    # while iterating through tokens, increment to traverse all drug, disease, and gene spans
    drug_index = 0
    disease_index = 0
    gene_index = 0
    
    tokens = tokenizer(text, return_offsets_mapping=True)

    for n in range(len(tokens["input_ids"])):
        offset_start, offset_end = tokens["offset_mapping"][n]

        # should only happen for [CLS] and [SEP]
        if offset_end - offset_start == 0:
            labels.append(-100)
            continue
        
        if (type(row["drug_indices_start"]) == list) and drug_index < len(row["drug_indices_start"]) and offset_start == row["drug_indices_start"][drug_index]:
            label = "DRUG"
            prefix = "B-"

        elif (type(row["disease_indices_start"]) == list) and disease_index < len(row["disease_indices_start"]) and offset_start == row["disease_indices_start"][disease_index]:
            label = "DISEASE"
            prefix = "B-"
            
        elif (type(row["gene_indices_start"]) == list) and gene_index < len(row["gene_indices_start"]) and offset_start == row["gene_indices_start"][gene_index]:
            label = "GENE"
            prefix = "B-"
        
        labels.append(label_list.index(f"{prefix}{label}"))
            
        if (type(row["drug_indices_end"]) == list) and drug_index < len(row["drug_indices_end"]) and offset_end == row["drug_indices_end"][drug_index]:
            label = "O"
            prefix = ""
            drug_index += 1
            
        elif (type(row["disease_indices_end"]) == list) and disease_index < len(row["disease_indices_end"]) and offset_end == row["disease_indices_end"][disease_index]:
            label = "O"
            prefix = ""
            disease_index += 1
            
        elif (type(row["gene_indices_end"]) == list) and gene_index < len(row["gene_indices_end"]) and offset_end == row["gene_indices_end"][gene_index]:
            label = "O"
            prefix = ""
            gene_index += 1

        # need to transition "inside" if we just entered an entity
        if prefix == "B-":
            prefix = "I-"
    
    if verbose:
        print(f"{row}\n")
        orig = tokenizer.convert_ids_to_tokens(tokens["input_ids"])
        for n in range(len(labels)):
            print(orig[n], labels[n])
    tokens["labels"] = labels
    
    return tokens
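The core idea of `generate_row_labels` is to compare each token's character offsets with the entity spans. The sketch below restates that idea without a tokenizer. The offsets are written by hand (hypothetical values for a whitespace split), `(0, 0)` marks special tokens, and spans are `(start, end, type)` triples. It is a simplified variant rather than a drop-in replacement: a token is tagged whenever it falls inside a span, instead of walking the start and end indices in order:

```python
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-DISEASE', 'I-DISEASE', 'B-GENE', 'I-GENE']

def bio_labels(offsets, spans):
    """Assign BIO tag ids to tokens from (start, end) character offsets and entity spans.

    offsets: list of (start, end) per token; (0, 0) marks [CLS]/[SEP].
    spans:   list of (start, end, entity_type) character spans.
    """
    labels = []
    for tok_start, tok_end in offsets:
        if tok_end - tok_start == 0:
            labels.append(-100)  # ignored by the loss and by compute_metrics
            continue
        tag = "O"
        for ent_start, ent_end, ent_type in spans:
            if ent_start <= tok_start and tok_end <= ent_end:
                # first token of the span gets B-, later tokens get I-
                prefix = "B-" if tok_start == ent_start else "I-"
                tag = prefix + ent_type
                break
        labels.append(label_list.index(tag))
    return labels

text = "aspirin inhibits cox - 1"
# Hypothetical offsets for: [CLS], aspirin, inhibits, cox, -, 1, [SEP]
offsets = [(0, 0), (0, 7), (8, 16), (17, 20), (21, 22), (23, 24), (0, 0)]
spans = [(0, 7, "DRUG"), (17, 24, "GENE")]
print(bio_labels(offsets, spans))  # [-100, 1, 0, 5, 6, 6, -100]
```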


In [8]:
labeled_dataset = cons_dataset.map(generate_row_labels)

Map: 100%|██████████| 6554/6554 [00:14<00:00, 452.94 examples/s]
Map: 100%|██████████| 7/7 [00:00<00:00, 102.88 examples/s]


---
## SciBERT Model Fine-Tuning

We are now ready to fine-tune the SciBERT model on our dataset. This section is adapted from the 🤗 token-classification example notebook: https://github.com/huggingface/notebooks/blob/master/examples/token_classification.ipynb


In [9]:
task = "ner" # Should be one of "ner", "pos" or "chunk"
#model_checkpoint = "allenai/scibert_scivocab_uncased"
batch_size = 16

In [10]:
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

In [48]:
#!pip install transformers[torch]
model_name = model_checkpoint.split("/")[-1]
args = TrainingArguments(
    f"{model_name}-finetuned-{task}",
    evaluation_strategy = "epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.05,
    logging_steps=1
)
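These arguments determine how many optimizer steps a training run takes. The estimate below assumes a single device, no gradient accumulation, and the default of keeping the last partial batch. Under those assumptions, 6,554 training rows at a batch size of 16 give 410 steps per epoch and 2,050 steps over 5 epochs. With `logging_steps=1`, every one of those steps is logged, and `evaluation_strategy="epoch"` adds one evaluation per epoch:

```python
import math

def total_steps(n_train: int, batch_size: int, epochs: int) -> int:
    """Estimate optimizer steps, assuming one device, no gradient accumulation,
    and that the last partial batch is kept."""
    steps_per_epoch = math.ceil(n_train / batch_size)
    return steps_per_epoch * epochs

print(math.ceil(6554 / 16))      # 410
print(total_steps(6554, 16, 5))  # 2050
```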

In [49]:
data_collator = DataCollatorForTokenClassification(tokenizer)

In [50]:
#!pip install seqeval
#module = evaluate.load("lvwerra/element_count")
metric = load_metric("seqeval")

In [51]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[pred] for (pred, lab) in zip(prediction, label) if lab != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[lab] for (pred, lab) in zip(prediction, label) if lab != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
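`compute_metrics` drops positions labelled `-100` before passing tag strings to seqeval, which scores whole entities rather than individual tokens. Below is a simplified, pure-Python sketch of that idea. It extracts `(type, start, end)` spans from BIO tags and counts exact span matches. It is not seqeval's exact algorithm: for example, it ignores an `I-` tag with no preceding `B-`, so its numbers may differ from the library's on malformed sequences. The function names are hypothetical:

```python
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-DISEASE', 'I-DISEASE', 'B-GENE', 'I-GENE']

def strip_ignored(pred_ids, gold_ids):
    """Drop positions whose gold id is -100 and map the remaining ids to tag strings."""
    pairs = [(p, g) for p, g in zip(pred_ids, gold_ids) if g != -100]
    return [label_list[p] for p, _ in pairs], [label_list[g] for _, g in pairs]

def extract_spans(tags):
    """Extract (type, start, end) spans from a BIO tag list; end is exclusive."""
    found, start, kind = [], None, None
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" closes a trailing entity
        continuing = tag.startswith("I-") and kind == tag[2:]
        if start is not None and not continuing:
            found.append((kind, start, i))
            start, kind = None, None
        if tag.startswith("B-"):
            start, kind = i, tag[2:]
    return set(found)

def span_prf(pred_tags, gold_tags):
    """Precision, recall, and F1 over exactly matching entity spans."""
    pred, gold = extract_spans(pred_tags), extract_spans(gold_tags)
    tp = len(pred & gold)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1

gold_ids = [-100, 1, 0, 5, 6, -100]  # [CLS] B-DRUG O B-GENE I-GENE [SEP]
pred_ids = [0, 1, 0, 5, 0, 0]        # the model misses the second gene token
pred_tags, gold_tags = strip_ignored(pred_ids, gold_ids)
print(gold_tags)                       # ['B-DRUG', 'O', 'B-GENE', 'I-GENE']
print(span_prf(pred_tags, gold_tags))  # (0.5, 0.5, 0.5)
```

A single wrong token makes the whole gene span count as a miss. This is why entity-level scores can be lower than token-level accuracy.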

In [52]:
trainer = Trainer(
    model,
    args,
    train_dataset=labeled_dataset["train"],
    eval_dataset=labeled_dataset["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
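    compute_metrics=compute_metrics,
)
# No cell shown here calls trainer.train(): the evaluation below scores the
# checkpoint loaded from `model_checkpoint`. Call trainer.train() here to fine-tune.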


In [53]:
predictions, labels, _ = trainer.predict(labeled_dataset["test"])
predictions = np.argmax(predictions, axis=2)

# Remove ignored index (special tokens)
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'DRUG': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 3},
 'GENE': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 13},
 'overall_precision': 1.0,
 'overall_recall': 1.0,
 'overall_f1': 1.0,
 'overall_accuracy': 1.0}

---
## See Model Outputs

We load our fine-tuned model into a `pipeline` object to run arbitrary input against it.

In [11]:
effect_ner_model = pipeline(task="ner", model=model, tokenizer=tokenizer)

In [12]:
# an example from the held-out test split
effect_ner_model(labeled_dataset["test"][6]["text"])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[{'entity': 'LABEL_0',
  'score': 0.9997309,
  'index': 1,
  'word': 'clinical',
  'start': 0,
  'end': 8},
 {'entity': 'LABEL_0',
  'score': 0.9997073,
  'index': 2,
  'word': 'studies',
  'start': 9,
  'end': 16},
 {'entity': 'LABEL_0',
  'score': 0.9996835,
  'index': 3,
  'word': 'have',
  'start': 17,
  'end': 21},
 {'entity': 'LABEL_0',
  'score': 0.99972457,
  'index': 4,
  'word': 'indicated',
  'start': 22,
  'end': 31},
 {'entity': 'LABEL_0',
  'score': 0.9997483,
  'index': 5,
  'word': 'the',
  'start': 32,
  'end': 35},
 {'entity': 'LABEL_0',
  'score': 0.9997286,
  'index': 6,
  'word': 'following',
  'start': 36,
  'end': 45},
 {'entity': 'LABEL_0',
  'score': 0.9997805,
  'index': 7,
  'word': 'relative',
  'start': 46,
  'end': 54},
 {'entity': 'LABEL_0',
  'score': 0.9997863,
  'index': 8,
  'word': 'potency',
  'start': 55,
  'end': 62},
 {'entity': 'LABEL_0',
  'score': 0.9997502,
  'index': 9,
  'word': 'differences',
  'start': 63,
  'end': 74},
 {'entity': 'LABEL

---
We wrap the model in a helper that renders its predicted entities with the displaCy library, then serve it through a Gradio demo with example sentences that mention drugs, genes, and diseases.

In [13]:
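def visualize_entities(sentence):
    tokens = effect_ner_model(sentence)
    entities = []
    # The model config has no tag names, so the pipeline reports "LABEL_<id>",
    # where <id> indexes label_list:
    # ['O', 'B-DRUG', 'I-DRUG', 'B-DISEASE', 'I-DISEASE', 'B-GENE', 'I-GENE']
    for token in tokens:
        label = int(token["entity"].split("_")[-1])
        if label != 0:  # skip "O" tokens
            token["label"] = label_list[label]
            entities.append(token)

    params = [{"text": sentence,
               "ents": entities,
               "title": None}]

    # jupyter=False makes displaCy return the HTML string (needed by gr.HTML)
    # instead of rendering it inline when running inside a notebook
    html = displacy.render(params, style="ent", manual=True, jupyter=False, options={
        "colors": {
            "B-DRUG": "#f08080",
            "I-DRUG": "#f08080",
            "B-DISEASE": "#9bddff",
            "I-DISEASE": "#9bddff",
            "B-GENE": "#008080",
            "I-GENE": "#008080",
        },
    })
    return html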
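`visualize_entities` sends each sub-word token to displaCy separately, so a multi-token entity appears as several adjacent boxes. One option is to merge a `B-` token with the `I-` tokens of the same type that immediately follow it. The sketch below works on plain dictionaries shaped like the pipeline output shown earlier (`entity`, `start`, `end`). The example tokens are hypothetical, and `merge_entities` is an illustrative helper, not part of any library. The 🤗 `pipeline` also accepts an `aggregation_strategy` argument (for example `"simple"`) that groups tokens itself. Its grouping rules differ from this sketch.

```python
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-DISEASE', 'I-DISEASE', 'B-GENE', 'I-GENE']

def merge_entities(tokens):
    """Merge per-token predictions into displaCy-style {'start', 'end', 'label'} spans."""
    merged, open_kind = [], None
    for token in tokens:
        tag = label_list[int(token["entity"].split("_")[-1])]
        if tag == "O":
            open_kind = None  # an "O" token closes any open entity
            continue
        prefix, kind = tag.split("-", 1)
        if prefix == "I" and open_kind == kind:
            merged[-1]["end"] = token["end"]  # extend the open entity
        else:
            # B- tag, or an I- tag that does not continue the open entity
            merged.append({"start": token["start"], "end": token["end"], "label": kind})
            open_kind = kind
    return merged

sentence = "Imatinib targets BCR-ABL"
tokens = [  # hypothetical pipeline output, scores omitted
    {"entity": "LABEL_1", "start": 0, "end": 8},    # imatinib
    {"entity": "LABEL_0", "start": 9, "end": 16},   # targets
    {"entity": "LABEL_5", "start": 17, "end": 20},  # bcr
    {"entity": "LABEL_6", "start": 20, "end": 21},  # -
    {"entity": "LABEL_6", "start": 21, "end": 24},  # abl
]
ents = merge_entities(tokens)
print([(sentence[e["start"]:e["end"]], e["label"]) for e in ents])
# [('Imatinib', 'DRUG'), ('BCR-ABL', 'GENE')]
```

If the merged list were used as `entities` in `visualize_entities`, the color keys would become `DRUG`, `DISEASE`, and `GENE` instead of the `B-`/`I-` tags.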

In [14]:
import gradio as gr

exampleList = [
    'Famotidine is a histamine H2-receptor antagonist used in inpatient settings for prevention of stress ulcers and is showing increasing popularity because of its low cost.',
    'A randomized Phase III trial demonstrated noninferiority of APF530 500 mg SC ( granisetron 10 mg ) to intravenous palonosetron 0.25 mg in preventing CINV in patients receiving MEC or HEC in acute ( 0 - 24 hours ) and delayed ( 24 - 120 hours ) settings , with activity over 120 hours .',
    'What are the known interactions between Aspirin and the COX-1 enzyme?',
    'Can you explain the mechanism of action of Metformin and its effect on the AMPK pathway?',
    'Are there any genetic variations in the CYP2C9 gene that may influence the response to Warfarin therapy?',
    'I am curious about the role of Herceptin in targeting the HER2/neu protein in breast cancer treatment. How does it work?',
    'What are the common side effects associated with Lisinopril, an angiotensin-converting enzyme (ACE) inhibitor?',
    'Can you explain the significance of the BCR-ABL fusion protein in the context of Imatinib therapy for chronic myeloid leukemia (CML)?',
    'How does Ibuprofen affect the COX-2 enzyme compared to COX-1?',
    'Are there any recent studies exploring the use of Pembrolizumab as an immune checkpoint inhibitor targeting PD-1?',
    'I have heard about the SLC6A4 gene and its association with serotonin reuptake inhibitors (SSRIs) like Fluoxetine.',
    'Could you provide insights into the BRAF mutation and its relevance in response to Vemurafenib treatment in melanoma patients?'
]

footer = """
LLMGeneLinker uses a domain-specific transformer (SciBERT) fine-tuned on the AllenAI drug dataset and the BC5CDR disease, NCBI disease, DrugProt, and GeneTAG datasets. The resulting SciBERT model performs named entity recognition to tag drugs, proteins, genes, and diseases in input text. Sentence embedding of SciBERT is then fed into BERT 
This was made during the <a target="_blank" href =https://www.sginnovate.com/event/hackathon-large-language-models-bio> LLMs for Bio Hackathon</a> organised by 4Catalyzer and SGInnovate.
<br>
Made by Team GeneLink (<a target="_blank" href=https://www.linkedin.com/in/ntkb/>Nicholas</a>, <a target="_blank" href=https://www.linkedin.com/in/yewchong-sim/>Yew Chong</a>, <a target="_blank" href=https://www.linkedin.com/in/lim-ting-wei-021383175/>Ting Wei</a>, <a target="_blank" href=https://www.linkedin.com/in/brendan-lim-ciwen/>Brendan</a>)
<hr>
Note: Performance is poorer on genes, acronyms, and receptors (named entities that may be targets for drugs or genes).
Original notebook adapted from <a target="_blank" href=https://huggingface.co/jsylee/scibert_scivocab_uncased-finetuned-ner>jsylee/scibert_scivocab_uncased-finetuned-ner</a>
"""

with gr.Blocks() as demo:
    gr.Markdown("## LLMGeneLinker (LGL)")
    gr.Markdown(footer)
    
    txt = gr.Textbox(label="Input", lines=2)
    txt_3 = gr.HTML(label="Output")
    btn = gr.Button(value="Submit")
    btn.click(visualize_entities, inputs=txt, outputs=txt_3)

    gr.Markdown("## Text Examples")
    gr.Examples(
        [[x] for x in exampleList],
        txt,
        txt_3,
        visualize_entities,
        cache_examples=False,
        run_on_click=True
    )


if __name__ == "__main__":
    demo.launch()

Running on local URL:  http://127.0.0.1:7860


To create a public link, set `share=True` in `launch()`.
