Commit 4343565 (1 parent: a134a9d): more settings

demo_watermark.py CHANGED (+9 -5)
@@ -261,7 +261,7 @@ def detect(input_text, args, device=None, tokenizer=None):
 
 def run_gradio(args, model=None, device=None, tokenizer=None):
 
-    generate_partial = partial(generate, model=model, device=
+    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
     detect_partial = partial(detect, device=device, tokenizer=tokenizer)
 
     with gr.Blocks() as demo:
@@ -289,11 +289,13 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                         generation_seed = gr.Number(label="Generation Seed",value=args.generation_seed, interactive=True)
                     with gr.Row():
                         n_beams = gr.Dropdown(label="Number of Beams",choices=list(range(1,11,1)), value=args.n_beams, visible=(not args.use_sampling))
+                    with gr.Row():
+                        max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens)
 
                 with gr.Column(scale=1):
                     gr.Markdown(f"#### Watermarking Parameters")
                     with gr.Row():
-                        gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.
+                        gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
                     with gr.Row():
                         delta = gr.Slider(label="delta",minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
                     with gr.Row():
@@ -326,6 +328,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
             elif value == "greedy":
                 return gr.update(visible=True)
         def update_n_beams(session_state, value): session_state.n_beams = int(value); return session_state
+        def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
         def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
         def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
 
@@ -337,6 +340,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
         generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
         n_beams.change(update_n_beams,inputs=[session_args, n_beams], outputs=[session_args])
+        max_new_tokens.change(update_max_new_tokens,inputs=[session_args, max_new_tokens], outputs=[session_args])
 
         gamma.change(update_gamma,inputs=[session_args, gamma], outputs=[session_args])
         delta.change(update_delta,inputs=[session_args, delta], outputs=[session_args])
@@ -365,7 +369,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         truncation_warning = gr.Number(visible=False)
         def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
             if truncation_warning:
-                return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]"
+                return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
             else:
                 return orig_prompt, args
 
@@ -412,7 +416,7 @@ def main(args):
     if not args.skip_model_load:
         model, tokenizer, device = load_model(args)
     else:
-        model, tokenizer, device = None, None,
+        model, tokenizer, device = None, None, None
 
     # Generate and detect, report to stdout
     if not args.skip_model_load:
@@ -442,7 +446,7 @@ def main(args):
         input_text = "In this work, we study watermarking of language model output. A watermark is a hidden pattern in text that is imperceptible to humans, while making the text algorithmically identifiable as synthetic. We propose an efficient watermark that makes synthetic text detectable from short spans of tokens (as few as 25 words), while false-positives (where human text is marked as machine-generated) are statistically improbable. The watermark detection algorithm can be made public, enabling third parties (e.g., social media platforms) to run it themselves, or it can be kept private and run behind an API. We seek a watermark with the following properties:\n"
 
 
-        term_width =
+        term_width = 80
        print("#"*term_width)
         print("Prompt:")
         print(input_text)
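The max_new_tokens wiring shows the demo's per-session settings pattern in miniature: each control gets a one-line updater that writes the new value into a session-scoped copy of the args object, and a .change hook routes the control through the updater and back into that state. Below is a minimal runnable sketch of the pattern, assuming session_args is a gr.State holding an argparse.Namespace (consistent with the update_* signatures above, though this diff doesn't show where session_args is created); the default value here is an illustrative stand-in, not the demo's.

from argparse import Namespace

import gradio as gr

# Illustrative defaults; the real demo builds `args` with argparse.
defaults = Namespace(max_new_tokens=200)

def update_max_new_tokens(session_state, value):
    # Write the widget's value into this session's private settings copy.
    session_state.max_new_tokens = int(value)
    return session_state

with gr.Blocks() as demo:
    # One settings object per browser session, so concurrent users of a
    # hosted Space don't overwrite each other's choices.
    session_args = gr.State(value=defaults)
    max_new_tokens = gr.Slider(label="Max Generated Tokens",
                               minimum=10, maximum=1000, step=10,
                               value=defaults.max_new_tokens)
    # Route every slider change through the updater and back into state.
    max_new_tokens.change(update_max_new_tokens,
                          inputs=[session_args, max_new_tokens],
                          outputs=[session_args])

if __name__ == "__main__":
    demo.launch()

Keeping the settings in gr.State rather than mutating the global args is what lets two users of the hosted Space hold different slider values at the same time, which is presumably why session_args is threaded through every updater.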
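The truncate_prompt change is a return-arity fix rather than a new setting: the else branch already returned the (prompt, args) pair, so the function presumably feeds a callback wired with two outputs, and returning the bare string in the truncation branch would make Gradio complain about a mismatched number of output components. Both branches now return a 2-tuple. (The outputs wiring itself isn't shown in this hunk, so that reading is inferred from the surrounding code.)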
|