Model Card

Base Model: facebook/bart-base

Fine-tuning: PEFT with LoRA

Datasets: squad_v2, drop

Task: Generating questions from context and answers

Language: English

Loading the model

  from peft import PeftModel, PeftConfig
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  HUGGING_FACE_USER_NAME = "mou3az"
  model_name = "Question-Generation"
  peft_model_id = f"{HUGGING_FACE_USER_NAME}/{model_name}"
  # Read the adapter config to find the base model it was trained on
  config = PeftConfig.from_pretrained(peft_model_id)
  # device_map='auto' requires the accelerate package to be installed
  model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=False, device_map='auto')
  QG_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
  # Attach the LoRA adapter weights to the base model
  QG_model = PeftModel.from_pretrained(model, peft_model_id)

At inference time

  def get_question(context, answer):
      # Run on whichever device the model was placed on
      device = next(QG_model.parameters()).device
      input_text = f"Given the context '{context}' and the answer '{answer}', what question can be asked?"
      # Calling the tokenizer directly replaces the deprecated encode_plus
      encoding = QG_tokenizer(input_text, return_tensors="pt").to(device)
  
      # Beam search with 5 beams, returning only the best sequence
      output_tokens = QG_model.generate(**encoding, early_stopping=True, num_beams=5, num_return_sequences=1, no_repeat_ngram_size=2, max_length=100)
      # Decode and strip any "question:" label from the generated text
      out = QG_tokenizer.decode(output_tokens[0], skip_special_tokens=True).replace("question:", "").strip()
  
      return out
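
The prompt template and the output clean-up inside get_question can be checked without downloading the model. The sketch below factors them into two helpers; the names build_prompt and clean_question are hypothetical and not part of the original card, and the raw string passed to clean_question is an invented placeholder, not real model output.

```python
def build_prompt(context, answer):
    # Same template string that get_question feeds to the tokenizer
    return f"Given the context '{context}' and the answer '{answer}', what question can be asked?"


def clean_question(decoded):
    # Same post-processing as get_question: drop "question:" labels, trim whitespace
    return decoded.replace("question:", "").strip()


prompt = build_prompt("The Eiffel Tower was completed in 1889.", "1889")
print(prompt)

# Placeholder decoded text, used only to illustrate the clean-up step
cleaned = clean_question("question: When was the Eiffel Tower completed? ")
print(cleaned)
```

With the model loaded as shown earlier, get_question("The Eiffel Tower was completed in 1889.", "1889") builds the same prompt and returns the cleaned generation.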

Training parameters and hyperparameters

The following were used during training:

For LoRA:

r=18

alpha=8
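
To put these two values in context: in standard LoRA the low-rank update is multiplied by alpha / r, and each adapted weight matrix of shape (d_out, d_in) gains r * (d_in + d_out) trainable parameters. The sketch below works through that arithmetic. It assumes the default (non-rank-stabilized) scaling and, purely for illustration, a single 768 x 768 projection (768 is the hidden size of bart-base); the card does not state which modules were adapted. In PEFT these values would presumably correspond to LoraConfig(r=18, lora_alpha=8).

```python
r = 18
lora_alpha = 8

# Standard LoRA scaling applied to the low-rank update B @ A
scaling = lora_alpha / r
print(f"{scaling:.3f}")  # 0.444

# Illustrative assumption: one 768 x 768 projection is adapted
d_in = d_out = 768
lora_params = r * (d_in + d_out)   # A is (r, d_in), B is (d_out, r)
full_params = d_in * d_out
print(lora_params)                 # 27648
print(f"{lora_params / full_params:.2%}")  # 4.69%
```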

For training arguments:

gradient_accumulation_steps=16

per_device_train_batch_size=8

per_device_eval_batch_size=8

max_steps=3000

warmup_steps=75

weight_decay=0.05

learning_rate=1e-3

lr_scheduler_type="linear"
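
These settings imply an effective batch size and a learning-rate curve that are easy to work out by hand. The sketch below assumes a single training device (the card does not say how many were used) and reproduces the shape of a linear schedule with warmup: a linear ramp from 0 to the peak rate over the warmup steps, followed by a linear decay to 0 at max_steps. The function name linear_lr is hypothetical.

```python
per_device_train_batch_size = 8
gradient_accumulation_steps = 16
max_steps = 3000
warmup_steps = 75
learning_rate = 1e-3
num_devices = 1  # assumption: not stated in the card

# Examples contributing to each optimizer step
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
print(effective_batch_size)  # 128

# Total training examples consumed over the run
examples_seen = effective_batch_size * max_steps
print(examples_seen)  # 384000


def linear_lr(step):
    # Linear warmup to the peak rate, then linear decay to zero
    if step < warmup_steps:
        return learning_rate * step / max(1, warmup_steps)
    remaining = max_steps - step
    return learning_rate * max(0.0, remaining / max(1, max_steps - warmup_steps))


for step in (0, 75, 1500, 3000):
    print(step, linear_lr(step))
```

The warmup covers 75 of 3000 steps, or 2.5% of the run, so the model trains at or near the peak rate of 1e-3 only briefly before the decay begins.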

Performance Metrics on Evaluation Set:

After 3000 optimization steps:

Training Loss: 1.292400

Evaluation Loss: 1.244928

BERTScore: 0.8123

ROUGE: 0.532144

FuzzyWuzzy similarity: 0.74209