meta-llama-3.1-segment-ppo Model Card
The meta-llama-3.1-segment-ppo model is trained with a segment-level reward model to improve reinforcement learning from human feedback (RLHF) in language models. It builds on the methods in our paper, Segmenting Text and Learning Their Rewards for Improved RLHF in Language Model.
Method Illustration
The figure illustrates the segment-based reward modeling method: entropy thresholds determine the segment boundaries, and the method integrates both reward model training and PPO training.
[Figure: Architecture]
Model Overview
This approach redefines the granularity of RLHF training by:
- Assigning rewards to semantically complete text segments, defined based on entropy thresholds.
- Introducing techniques to stabilize RLHF training under dense, segment-level rewards.
Model checkpoints are available on Hugging Face.
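The entropy-based segmentation described above can be sketched in a few lines. This is a minimal illustration, not the training code from the paper: it assumes that a new segment starts at any token whose predictive entropy exceeds the threshold, and the function names, the example entropies, and the threshold value of 1.5 are all hypothetical.

```python
import math
from typing import List, Sequence


def token_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (in nats) of a single next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def segment_by_entropy(entropies: Sequence[float], threshold: float) -> List[List[int]]:
    """Group token indices into contiguous segments.

    Assumption for this sketch: a token whose entropy is above `threshold`
    opens a new segment; every other token joins the current segment.
    """
    segments: List[List[int]] = []
    for index, entropy in enumerate(entropies):
        if not segments or entropy > threshold:
            segments.append([index])
        else:
            segments[-1].append(index)
    return segments


# Hypothetical per-token entropies for a 7-token response.
entropies = [2.1, 0.3, 0.2, 1.8, 0.4, 0.1, 2.5]
segments = segment_by_entropy(entropies, threshold=1.5)
print(segments)  # [[0, 1, 2], [3, 4, 5], [6]]
```

Under this assumption, high-entropy tokens mark points where the model is uncertain about what comes next, so they act as natural boundaries, and a segment-level reward model would then score each group of tokens rather than each token or the whole response.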
Training Data
We utilize the following datasets in our training pipeline:
- Preference-700K Dataset: A diverse collection of open-source preference datasets, including HH-RLHF, Stanford Human Preferences Dataset (SHP), and HelpSteer.
- UltraFeedback Dataset: Used to sample prompts during PPO training.
Base Model
The meta-llama-3.1-segment-ppo model is fine-tuned from meta-llama/Llama-3.1-8B-Instruct.
Usage
You can use this model directly with Hugging Face's Transformers library:
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model and tokenizer
model_name = "yyqoni/meta-llama-3.1-instruct-8b-segment-ppo-60k"
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Input text
input_text = "What are the benefits of using reinforcement learning in AI?"
# Apply chat template formatting with generation prompt
formatted_input = tokenizer.apply_chat_template(
    [{"role": "user", "content": input_text}],
    tokenize=False,
    add_generation_prompt=True,
)
# Tokenize the formatted input
inputs = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False)
# Generate response
outputs = model.generate(**inputs, max_new_tokens=50)
# Decode and print the response
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
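For decoder-only models, `generate` returns the prompt tokens followed by the newly generated tokens, so the decoded string above includes the prompt. If only the reply is wanted, the prompt can be sliced off before decoding. The helper below is a small sketch of that slicing on plain token-id lists; the name `new_token_ids` is hypothetical and not part of Transformers.

```python
from typing import List, Sequence


def new_token_ids(output_ids: Sequence[int], prompt_length: int) -> List[int]:
    """Return only the token ids that follow the first `prompt_length` ids."""
    if prompt_length < 0 or prompt_length > len(output_ids):
        raise ValueError("prompt_length must be between 0 and len(output_ids)")
    return list(output_ids[prompt_length:])


# With the usage snippet above, this would look roughly like (not run here):
#   prompt_length = inputs["input_ids"].shape[1]
#   reply_ids = new_token_ids(outputs[0].tolist(), prompt_length)
#   print(tokenizer.decode(reply_ids, skip_special_tokens=True))
```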
Citation
If you find this model or our research useful, please consider citing our paper:
@misc{yin2025segmentingtextlearningrewards,
title={Segmenting Text and Learning Their Rewards for Improved RLHF in Language Model},
author={Yueqin Yin and Shentao Yang and Yujia Xie and Ziyi Yang and Yuting Sun and Hany Awadalla and Weizhu Chen and Mingyuan Zhou},
year={2025},
eprint={2501.02790},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2501.02790},
}