Flan-T5 (base-sized) Dialogue Summarization with reduced toxicity using RLAIF

This model is a Flan-T5 model fine-tuned in two stages: first on the SAMSum dataset, then further with Reinforcement Learning from AI Feedback (RLAIF) to detoxify its outputs.
Anthropic's 2022 Constitutional AI paper provides useful insights into how RLAIF can be leveraged; it is worth a read if you are interested.

More specifically, I've fine-tuned this model on a single downstream task, dialogue summarization on the dataset mentioned above, with the primary objective of reducing toxicity in the generated summaries.
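This card does not spell out how the AI feedback is turned into a reward. A common setup for detoxification, assumed here and not confirmed as the author's method, scores each generated summary with a toxicity classifier and uses the probability of the "not toxic" class as the scalar reward. The sketch below shows only that reward computation. The function names, the label order and the logit values are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of raw classifier scores."""
    shift = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def detox_reward(classifier_logits: Sequence[float], not_toxic_index: int = 0) -> float:
    """Hypothetical reward: the classifier's probability that a summary is NOT toxic.

    `not_toxic_index` is an assumption; check the label order of whatever
    toxicity classifier you actually use.
    """
    return softmax(classifier_logits)[not_toxic_index]


# Made-up logits for two candidate summaries, ordered as [not_toxic, toxic].
clean = detox_reward([3.0, -1.0])
toxic = detox_reward([-2.0, 2.5])
print(f"clean={clean:.3f} toxic={toxic:.3f}")
```

In a PPO-style loop, a value like this would typically serve as the per-sample reward, so the policy is pushed toward summaries the classifier rates as less toxic.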

Model description

This model has the same architecture and parameters as its base model. Please refer to this link for more details about the model.

Intended Use & Limitations

This model is intended to summarize a given dialogue while producing a less toxic summary, even when the dialogue itself contains toxic phrases or words.
I've fine-tuned the model with the instruction "Summarize the following Conversation:" prepended to each dialogue, followed by the keyword "Summary:" at the end to mark where the summary starts.
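The prompt format described above can be built with a small helper. This is a minimal sketch: `build_prompt` is a hypothetical name, and joining the speaker turns with single spaces mirrors the example in the Usage section. The exact whitespace used during training is not stated here, so the separator is left as a parameter.

```python
from typing import Iterable, Tuple

INSTRUCTION = "Summarize the following Conversation:"
SUMMARY_KEYWORD = "Summary:"


def build_prompt(turns: Iterable[Tuple[str, str]], sep: str = " ") -> str:
    """Format (speaker, utterance) pairs into the instruction prompt.

    The instruction goes first, then the dialogue, then the "Summary:" keyword
    that signals where the model should start writing the summary.
    """
    dialogue = sep.join(f"{speaker}: {utterance}" for speaker, utterance in turns)
    return f"{INSTRUCTION}{sep}{dialogue}{sep}{SUMMARY_KEYWORD}"


prompt = build_prompt([("Dean", "I feel sick"), ("Scott", "hungover?")])
print(prompt)
```

The resulting string can be passed straight to `tokenizer.encode(prompt, return_tensors="pt")`.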

Notes:

  1. Because the model is primarily trained to reduce toxicity in its outputs, it can sometimes produce relatively short summaries that may (rarely) miss the important message of the dialogue, while still meeting its primary goal.
  2. Currently, Hugging Face's Hosted Inference API doesn't directly support PEFT model files for the Text2Text-Generation pipeline, so please follow the steps in the Usage section below to load and use the model.
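If you would rather use the standard `pipeline` API locally, one option is to merge the LoRA weights into the base model with PEFT's `merge_and_unload()`, which returns a plain transformers model. This is a sketch that has not been verified against this specific checkpoint. The function name `load_merged_summarizer` is hypothetical. The base model is loaded in full precision because merging into 8-bit weights is not straightforward. Imports are deferred so the function can be defined without the libraries installed.

```python
def load_merged_summarizer(
    peft_model_id: str = "DeathReaper0965/flan-t5-samsum-lora-RLAIF-detoxified",
):
    """Hypothetical helper: base model + LoRA adapter -> merged text2text pipeline."""
    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

    # The adapter config records which base checkpoint the LoRA weights belong to.
    config = PeftConfig.from_pretrained(peft_model_id)
    base = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

    # Fold the LoRA deltas into the base weights; the result no longer needs PEFT.
    merged = PeftModel.from_pretrained(base, peft_model_id).merge_and_unload()
    return pipeline("text2text-generation", model=merged, tokenizer=tokenizer)
```

Usage: `summarizer = load_merged_summarizer()`, then `summarizer(prompt, max_new_tokens=256, do_sample=True, top_p=0.95, temperature=0.6)[0]["generated_text"]`, where `prompt` follows the instruction format described above.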

Usage

You can use this model directly to generate summaries:

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


# Load the PEFT config, which records the base checkpoint the adapter was trained on
peft_model_id = "DeathReaper0965/flan-t5-samsum-lora-RLAIF-detoxified"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model and tokenizer
# (if required, add `load_in_8bit=True` to load the model in 8-bit)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Attach the LoRA adapter weights to the base model
model = PeftModel.from_pretrained(model, peft_model_id, device_map="auto")
model.eval()

# Encode the dialogue using the same prompt format as during fine-tuning
input_ids = tokenizer.encode(
    "Summarize the following Conversation: Dean: I feel sick Scott: hungover? Dean: no, like I ate something bad Scott: what did you eat yesterday? Dean: breakfast at Coffee Lovers' Scott: this is a rather safe place Dean: and Chinese from TaoTao for dinner Scott: now we have a suspect Summary:",
    return_tensors="pt",
).to("cuda" if torch.cuda.is_available() else "cpu")

# Sample a summary (nucleus + top-k sampling with repetition controls)
summary = model.generate(
    input_ids=input_ids,
    max_new_tokens=256,
    repetition_penalty=2.5,
    top_p=0.95,
    top_k=50,
    temperature=0.6,
    no_repeat_ngram_size=2,
    num_return_sequences=1,
    do_sample=True,
)

output = tokenizer.batch_decode(summary, skip_special_tokens=True)
print(output[0])

# Example output (sampled with do_sample=True, so your result may differ):
# "Dean ate breakfast at Coffee Lovers' yesterday and Chinese from TaoTao for dinner."

Designed and Developed with ♥ by Praneet | LinkedIn | GitHub
