from datasets import load_dataset
from unfat.datasets import hub_prompts, HubSplit, Dataset, Prompts
from unfat.extract import Extractor, ClientOpts
from unfat.lora import LoraSettings
import os

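# Everything the extractor produces is written under this directory.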
output_dir = "output"

uncensor_ds_name = "Guilherme34/uncensor"
uncensor_ds = load_dataset(uncensor_ds_name, split="train")

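# Yield the first user message of each conversation as a training prompt.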
def uncensor_items():
    for row in uncensor_ds:
        for message in row["messages"]:
            if message["role"] == "user":
                yield message["content"]
                break


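# Query the teacher model for completions of every train/eval prompt through
# glhf.chat's OpenAI-compatible API; results are saved under output_dir.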
extractor = Extractor(
    teacher="hf:mlabonne/Llama-3.1-70B-Instruct-lorablated",
    max_concurrent=8,
    output_dir=output_dir,
    client_opts=ClientOpts(
        base_url="https://glhf.chat/api/openai/v1",
        api_key=os.environ["GLHF_API_KEY"],
    ),
    dataset=Dataset(
        train=[
            Prompts(
                output_path=f"hub/{uncensor_ds_name}.jsonl",
                count=lambda: len(uncensor_ds),
                items=uncensor_items,
            ),
            hub_prompts(
                name="mlabonne/harmful_behaviors",
                text_field="text",
                split=HubSplit(name="train"),
            ),
        ],
        eval=[
            hub_prompts(
                name="mlabonne/harmful_behaviors",
                text_field="text",
                split=HubSplit(name="test"),
            ),
        ],
    ),
)

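# LoRA hyperparameters for the student adapter.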
lora_settings = LoraSettings(
    lora_r=32,
    lora_alpha=16,
    lora_dropout=0.01,
    num_epochs=2,
    learning_rate=4e-4,
    warmup_steps=10,
)

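# Generate an Axolotl training config for Llama 3.1 70B from the extracted dataset.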
axolotl_config = lora_settings.llama_70b_axolotl(extractor.output_dataset())

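# Run the extraction, then write the Axolotl config to the output directory.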
extractor.run()
axolotl_config.save(output_dir)
|