File size: 1,764 Bytes
71e47a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import functools
from pathlib import Path

import yaml


def default_preset():
    """Return a fresh dict holding the baseline value of every preset parameter.

    A new dict is constructed on each call, so the caller may modify the
    result without affecting later calls. Key order is significant: other
    helpers derive the parameter ordering from it.
    """
    return dict(
        do_sample=True,
        temperature=1,
        top_p=1,
        top_k=0,
        typical_p=1,
        epsilon_cutoff=0,
        eta_cutoff=0,
        tfs=1,
        top_a=0,
        repetition_penalty=1,
        repetition_penalty_range=0,
        encoder_repetition_penalty=1,
        no_repeat_ngram_size=0,
        min_length=0,
        guidance_scale=1,
        mirostat_mode=0,
        mirostat_tau=5.0,
        mirostat_eta=0.1,
        penalty_alpha=0,
        num_beams=1,
        length_penalty=1,
        early_stopping=False,
        custom_token_bans='',
    )


def presets_params():
    """Return the preset parameter names as a list, in default_preset() key order."""
    return list(default_preset().keys())


def load_preset(name):
    """Build generation parameters from the defaults plus an optional preset file.

    Args:
        name: Preset name. ``'None'``, ``None`` and ``''`` mean "no preset";
            any other value is read from ``presets/{name}.yaml``, resolved
            relative to the current working directory.

    Returns:
        dict: The default parameters overridden by the entries found in the
        preset file, with ``temperature`` capped at 1.99.

    Raises:
        OSError: If the preset file cannot be opened.
        yaml.YAMLError: If the preset file is not valid YAML.
    """
    generate_params = default_preset()
    if name not in ['None', None, '']:
        # Read as UTF-8 explicitly so parsing does not depend on the platform locale.
        with open(Path(f'presets/{name}.yaml'), 'r', encoding='utf-8') as infile:
            # safe_load returns None for an empty document; treat that as
            # "no overrides" instead of failing when iterating over None.
            preset = yaml.safe_load(infile) or {}

        generate_params.update(preset)

    # The cap is applied whether or not a preset file was loaded.
    generate_params['temperature'] = min(1.99, generate_params['temperature'])
    return generate_params


@functools.lru_cache(maxsize=None)
def load_preset_memoized(name):
    """Cached variant of load_preset(): each distinct ``name`` is loaded once.

    The cache is unbounded. Repeated calls with the same ``name`` return the
    very same dict object, so callers that modify the result also modify what
    later callers receive.
    """
    params = load_preset(name)
    return params


def load_preset_for_ui(name, state):
    """Load a preset, merge it into ``state``, and return the individual values.

    Args:
        name: Preset name, forwarded to load_preset().
        state: Mapping that is updated in place with every loaded parameter.

    Returns:
        tuple: ``state`` followed by one value per name in presets_params(),
        in that order.
    """
    loaded = load_preset(name)
    state.update(loaded)
    ordered_values = [loaded[param] for param in presets_params()]
    return (state, *ordered_values)


def generate_preset_yaml(state):
    """Serialize the non-default preset parameters found in ``state`` as YAML.

    Args:
        state: Mapping that must contain every name in presets_params();
            a missing name raises KeyError.

    Returns:
        str: YAML text listing only the parameters whose value differs from
        default_preset(), in presets_params() order.
    """
    baseline = default_preset()
    overrides = {}
    for param in presets_params():
        value = state[param]
        # Entries equal to their default are left out of the dump.
        if value == baseline[param]:
            continue
        overrides[param] = value

    return yaml.dump(overrides, sort_keys=False)