grahamwhiteuk commited on
Commit
03c2ae6
·
1 Parent(s): 0a8d079

feat: advanced settings

Browse files

Signed-off-by: Graham White <[email protected]>

Files changed (1) hide show
  1. src/app.py +32 -7
src/app.py CHANGED
@@ -46,7 +46,15 @@ tokenizer.use_default_system_prompt = False
46
 
47
 
48
  @spaces.GPU
49
- def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
 
 
 
 
 
 
 
 
50
  """Generate function for chat demo."""
51
  # Build messages
52
  conversation = []
@@ -60,7 +68,7 @@ def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
60
  return_tensors="pt",
61
  add_generation_prompt=True,
62
  truncation=True,
63
- max_length=MAX_INPUT_TOKEN_LENGTH,
64
  )
65
 
66
  input_ids = input_ids.to(model.device)
@@ -68,13 +76,13 @@ def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
68
  generate_kwargs = dict(
69
  {"input_ids": input_ids},
70
  streamer=streamer,
71
- max_new_tokens=MAX_NEW_TOKENS,
72
  do_sample=True,
73
- top_p=TOP_P,
74
- top_k=TOP_K,
75
- temperature=TEMPERATURE,
76
  num_beams=1,
77
- repetition_penalty=REPETITION_PENALTY,
78
  )
79
 
80
  t = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -89,6 +97,15 @@ def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
89
  css_file_path = Path(Path(__file__).parent / "app.css")
90
  head_file_path = Path(Path(__file__).parent / "app_head.html")
91
 
 
 
 
 
 
 
 
 
 
92
 
93
  with gr.Blocks(
94
  fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=carbon_theme, title=TITLE
@@ -108,6 +125,14 @@ with gr.Blocks(
108
  ],
109
  cache_examples=False,
110
  type="messages",
 
 
 
 
 
 
 
 
111
  )
112
 
113
  if __name__ == "__main__":
 
46
 
47
 
48
  @spaces.GPU
49
+ def generate(
50
+ message: str,
51
+ chat_history: list[dict],
52
+ temperature: float = TEMPERATURE,
53
+ top_p: float = TOP_P,
54
+ top_k: float = TOP_K,
55
+ repetition_penalty: float = REPETITION_PENALTY,
56
+ max_new_tokens: int = MAX_NEW_TOKENS,
57
+ ) -> Iterator[str]:
58
  """Generate function for chat demo."""
59
  # Build messages
60
  conversation = []
 
68
  return_tensors="pt",
69
  add_generation_prompt=True,
70
  truncation=True,
71
+ max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
72
  )
73
 
74
  input_ids = input_ids.to(model.device)
 
76
  generate_kwargs = dict(
77
  {"input_ids": input_ids},
78
  streamer=streamer,
79
+ max_new_tokens=max_new_tokens,
80
  do_sample=True,
81
+ top_p=top_p,
82
+ top_k=top_k,
83
+ temperature=temperature,
84
  num_beams=1,
85
+ repetition_penalty=repetition_penalty,
86
  )
87
 
88
  t = Thread(target=model.generate, kwargs=generate_kwargs)
 
97
  css_file_path = Path(Path(__file__).parent / "app.css")
98
  head_file_path = Path(Path(__file__).parent / "app_head.html")
99
 
100
+ # advanced settings (displayed in Accordion)
101
+ temperature_slider = gr.Slider(minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature")
102
+ top_p_slider = gr.Slider(minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P")
103
+ top_k_slider = gr.Slider(minimum=0, maximum=100, value=TOP_K, step=1, label="Top K")
104
+ repetition_penalty_slider = gr.Slider(
105
+ minimum=0, maximum=2.0, value=REPETITION_PENALTY, step=0.1, label="Repetition Penalty"
106
+ )
107
+ max_new_tokens_slider = gr.Slider(minimum=1, maximum=2000, value=MAX_NEW_TOKENS, step=1, label="Max New Tokens")
108
+ chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)
109
 
110
  with gr.Blocks(
111
  fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=carbon_theme, title=TITLE
 
125
  ],
126
  cache_examples=False,
127
  type="messages",
128
+ additional_inputs=[
129
+ temperature_slider,
130
+ top_p_slider,
131
+ top_k_slider,
132
+ repetition_penalty_slider,
133
+ max_new_tokens_slider,
134
+ ],
135
+ additional_inputs_accordion=chat_interface_accordion,
136
  )
137
 
138
  if __name__ == "__main__":