distilgpt2-multiprompt

Generate/augment your prompt with a model trained on a large & diverse prompt dataset.

This model is a fine-tuned version of distilgpt2 on the pszemraj/text2image-prompts-multi dataset. It achieves the following results on the evaluation set:

  • Loss: 2.0213
  • Perplexity: 7.55
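
The two numbers above are related: assuming the reported loss is the mean per-token cross-entropy in nats (the usual convention for causal language models, though this card does not state it), perplexity is simply the exponential of the loss. A minimal check:

```python
import math

# Evaluation loss reported in this card (assumed to be mean cross-entropy in nats).
eval_loss = 2.0213

# Perplexity is the exponential of the mean cross-entropy.
perplexity = math.exp(eval_loss)

print(f"perplexity = {perplexity:.2f}")
```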

Intended uses & limitations

  • The model generates augmentations that are biased towards its training data, i.e. what people have already asked for in the Stable Diffusion/Midjourney Discords and similar sources. Compiling a larger dataset from several different sources was an attempt to mitigate this bias.
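
For the intended use, prompt augmentation, the following is a minimal sketch using the `text-generation` pipeline from the Transformers library. The helper name `augment_prompt` and the sampling settings (`top_p`, `max_new_tokens`, number of candidates) are illustrative assumptions, not values recommended by this card.

```python
from typing import List

# Hub repository id of this model.
MODEL_ID = "pszemraj/distilgpt2-multiprompt"


def augment_prompt(seed: str, n: int = 3, max_new_tokens: int = 48) -> List[str]:
    """Return n sampled continuations of a short seed prompt.

    Hypothetical helper for illustration; the sampling settings are assumptions.
    """
    # Imported here so that defining the helper does not require transformers.
    from transformers import pipeline

    generator = pipeline("text-generation", model=MODEL_ID)
    outputs = generator(
        seed,
        max_new_tokens=max_new_tokens,
        do_sample=True,           # sample rather than decode greedily
        top_p=0.95,               # nucleus sampling threshold (assumed value)
        num_return_sequences=n,
    )
    # Each result is a dict whose "generated_text" holds the seed plus continuation.
    return [o["generated_text"] for o in outputs]
```

Calling `augment_prompt("a castle on a hill")` downloads the model on first use and returns a list of candidate prompts. Because the outputs are sampled, they differ between runs and will tend to reflect the training-data bias noted above.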

Training and evaluation data

See the pszemraj/text2image-prompts-multi dataset card for details. The dataset is a compilation of several text-to-image prompt datasets hosted on Hugging Face :)

Training procedure

  • The model was trained over several training rounds, for 8 epochs in total on the train set.

Training hyperparameters (last training round)

The following hyperparameters were used during training:

  • learning_rate: 0.0006
  • train_batch_size: 16
  • eval_batch_size: 4
  • seed: 42
  • distributed_type: multi-GPU
  • num_devices: 2
  • gradient_accumulation_steps: 8
  • total_train_batch_size: 256
  • total_eval_batch_size: 8
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: cosine
  • lr_scheduler_warmup_ratio: 0.01
  • num_epochs: 2.0
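
The total batch sizes in the list follow from the per-device values. A short check of the arithmetic, assuming the usual convention that gradient accumulation applies only to training (variable names are illustrative):

```python
# Values copied from the hyperparameter list above.
train_batch_size = 16            # per device
eval_batch_size = 4              # per device
num_devices = 2
gradient_accumulation_steps = 8

# One optimizer step sees every device's batch, accumulated over several passes.
total_train_batch_size = train_batch_size * num_devices * gradient_accumulation_steps

# Evaluation does not accumulate gradients, so only the device count multiplies.
total_eval_batch_size = eval_batch_size * num_devices

print(total_train_batch_size, total_eval_batch_size)
```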

Training results

Training Loss   Epoch   Step   Validation Loss
2.1637          1.0     965    2.0581
2.0885          2.0     1930   2.0213
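
To give a feel for the learning-rate schedule over these 1930 steps, here is a rough sketch of a cosine schedule with linear warmup, using the learning rate (0.0006) and warmup ratio (0.01) listed under the hyperparameters. The function `lr_at_step` is hypothetical, and rounding the warmup length up to a whole step is an assumption; the schedule actually used in training may differ in such details.

```python
import math


def lr_at_step(step: int, total_steps: int = 1930,
               base_lr: float = 6e-4, warmup_ratio: float = 0.01) -> float:
    """Learning rate at an optimizer step: linear warmup, then cosine decay."""
    # Assumption: the warmup length is rounded up to a whole number of steps.
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, from 0 to 1.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Half-cosine decay from base_lr down to 0.
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Start, end of warmup, end of epoch 1, end of training.
for s in (0, 20, 965, 1930):
    print(s, lr_at_step(s))
```

Under these assumptions the warmup lasts about 20 steps, so nearly all of the final training round runs on the decaying part of the curve.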

Framework versions

  • Transformers 4.25.0.dev0
  • Pytorch 1.13.0+cu117
  • Datasets 2.6.1
  • Tokenizers 0.13.1