NULLNode committed on
Commit
893a1df
1 Parent(s): c935af6

Upload 6 files

Files changed (6)
  1. app.py +52 -0
  2. demo_watermark.py +1772 -0
  3. homoglyphs.py +265 -0
  4. normalizers.py +195 -0
  5. requirements.txt +7 -0
  6. watermark_processor.py +280 -0
app.py ADDED
@@ -0,0 +1,52 @@
+ # coding=utf-8
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
+ # available at https://arxiv.org/abs/2301.10226
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from argparse import Namespace
+ args = Namespace()
+
+ arg_dict = {
+     'run_gradio': True,
+     'demo_public': False,
+     # 'model_name_or_path': 'facebook/opt-125m',
+     # 'model_name_or_path': 'facebook/opt-1.3b',
+     # 'model_name_or_path': 'facebook/opt-2.7b',
+     'model_name_or_path': 'facebook/opt-6.7b',
+     # 'model_name_or_path': 'facebook/opt-13b',
+     # 'load_fp16': True,
+     'load_fp16': False,
+     'prompt_max_length': None,
+     'max_new_tokens': 200,
+     'generation_seed': 123,
+     'use_sampling': True,
+     'n_beams': 1,
+     'sampling_temp': 0.7,
+     'use_gpu': True,
+     'seeding_scheme': 'simple_1',
+     'gamma': 0.25,
+     'delta': 2.0,
+     'normalizers': '',
+     'ignore_repeated_bigrams': False,
+     'detection_z_threshold': 4.0,
+     'select_green_tokens': True,
+     'skip_model_load': True,
+     'seed_separately': True,
+ }
+
+ args.__dict__.update(arg_dict)
+
+ from demo_watermark import main
+
+ main(args)
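app.py drives the demo by building an `argparse.Namespace` directly instead of parsing command-line flags. A minimal sketch of the same launch pattern with two hypothetical overrides (a smaller model and fp16; field names mirror `arg_dict` above, and `skip_model_load` is flipped to False here so the model actually loads, whereas the Space config above keeps it True):

```python
from argparse import Namespace
from demo_watermark import main

# Sketch only: same config as app.py with two illustrative overrides.
args = Namespace(**{
    'run_gradio': True, 'demo_public': False,
    'model_name_or_path': 'facebook/opt-1.3b',  # hypothetical smaller model
    'load_fp16': True,                          # halve memory on GPU
    'prompt_max_length': None, 'max_new_tokens': 200,
    'generation_seed': 123, 'use_sampling': True, 'n_beams': 1,
    'sampling_temp': 0.7, 'use_gpu': True, 'seeding_scheme': 'simple_1',
    'gamma': 0.25, 'delta': 2.0, 'normalizers': '',
    'ignore_repeated_bigrams': False, 'detection_z_threshold': 4.0,
    'select_green_tokens': True, 'skip_model_load': False,
    'seed_separately': True,
})
main(args)
```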
demo_watermark.py ADDED
@@ -0,0 +1,1772 @@
+ # coding=utf-8
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
+ # available at https://arxiv.org/abs/2301.10226
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import argparse
+ from pprint import pprint
+ from functools import partial
+
+ import numpy  # for gradio hot reload
+ import gradio as gr
+
+ import torch
+
+ from transformers import (AutoTokenizer,
+                           AutoModelForSeq2SeqLM,
+                           AutoModelForCausalLM,
+                           LogitsProcessorList)
+
+ # from local_tokenizers.tokenization_llama import LLaMATokenizer
+
+ from transformers import GPT2TokenizerFast
+
+ OPT_TOKENIZER = GPT2TokenizerFast
+
+ from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
+
+ # ALPACA_MODEL_NAME = "alpaca"
+ # ALPACA_MODEL_TOKENIZER = LLaMATokenizer
+ # ALPACA_TOKENIZER_PATH = "/cmlscratch/jkirchen/llama"
+
+ # FIXME correct lengths for all models
+ API_MODEL_MAP = {
+     "google/flan-ul2": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+     "google/flan-t5-xxl": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+     "EleutherAI/gpt-neox-20b": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+     # "bigscience/bloom": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+     # "bigscience/bloomz": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+ }
+
+
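The map above supplies per-model defaults when the demo calls the hosted Inference API instead of a local model. A quick sketch of the lookup pattern used later in `check_prompt`:

```python
# Per-model API defaults are keyed by hub id; models not in the map
# fall back to the model config or a 2048-token default elsewhere.
cfg = API_MODEL_MAP.get("google/flan-ul2", {})
print(cfg.get("max_length"))  # 1000
```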
+ def str2bool(v):
+     """Util function for user friendly boolean flag args"""
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ('yes', 'true', 't', 'y', '1'):
+         return True
+     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+         return False
+     else:
+         raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
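`str2bool` exists so boolean flags can be passed explicitly as values (`--use_sampling false`) rather than as presence-only switches. A minimal usage sketch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_sampling", type=str2bool, default=True)

print(parser.parse_args(["--use_sampling", "no"]).use_sampling)  # False
print(parser.parse_args([]).use_sampling)                        # True (default untouched)
```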
+ def parse_args():
+     """Command line argument specification"""
+
+     parser = argparse.ArgumentParser(
+         description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")
+
+     parser.add_argument(
+         "--run_gradio",
+         type=str2bool,
+         default=True,
+         help="Whether to launch as a gradio demo. Set to False if gradio is not installed and you only want the stdout version.",
+     )
+     parser.add_argument(
+         "--demo_public",
+         type=str2bool,
+         default=False,
+         help="Whether to expose the gradio demo to the internet.",
+     )
+     parser.add_argument(
+         "--model_name_or_path",
+         type=str,
+         default="facebook/opt-6.7b",
+         help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--prompt_max_length",
+         type=int,
+         default=None,
+         help="Truncation length for the prompt, overrides the model config's max length field.",
+     )
+     parser.add_argument(
+         "--max_new_tokens",
+         type=int,
+         default=200,
+         help="Maximum number of new tokens to generate.",
+     )
+     parser.add_argument(
+         "--generation_seed",
+         type=int,
+         default=123,
+         help="Seed for setting the torch global rng prior to generation.",
+     )
+     parser.add_argument(
+         "--use_sampling",
+         type=str2bool,
+         default=True,
+         help="Whether to generate using multinomial sampling.",
+     )
+     parser.add_argument(
+         "--sampling_temp",
+         type=float,
+         default=0.7,
+         help="Sampling temperature to use when generating using multinomial sampling.",
+     )
+     parser.add_argument(
+         "--n_beams",
+         type=int,
+         default=1,
+         help="Number of beams to use for beam search. 1 is normal greedy decoding.",
+     )
+     parser.add_argument(
+         "--use_gpu",
+         type=str2bool,
+         default=True,
+         help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
+     )
+     parser.add_argument(
+         "--seeding_scheme",
+         type=str,
+         default="simple_1",
+         help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
+     )
+     parser.add_argument(
+         "--gamma",
+         type=float,
+         default=0.25,
+         help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
+     )
+     parser.add_argument(
+         "--delta",
+         type=float,
+         default=2.0,
+         help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
+     )
+     parser.add_argument(
+         "--normalizers",
+         type=str,
+         default="",
+         help="Single or comma separated list of the preprocessor/normalizer names to use when performing watermark detection.",
+     )
+     parser.add_argument(
+         "--ignore_repeated_bigrams",
+         type=str2bool,
+         default=False,
+         help="Whether to use the detection method that only counts each unique bigram once as either a green or red hit.",
+     )
+     parser.add_argument(
+         "--detection_z_threshold",
+         type=float,
+         default=4.0,
+         help="The test statistic threshold for the detection hypothesis test.",
+     )
+     parser.add_argument(
+         "--select_green_tokens",
+         type=str2bool,
+         default=True,
+         help="How to treat the permutation when selecting the greenlist tokens at each step. Legacy behavior (False) picks the complement/reds first.",
+     )
+     parser.add_argument(
+         "--skip_model_load",
+         type=str2bool,
+         default=False,
+         help="Skip the model loading to debug the interface.",
+     )
+     parser.add_argument(
+         "--seed_separately",
+         type=str2bool,
+         default=True,
+         help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
+     )
+     parser.add_argument(
+         "--load_fp16",
+         type=str2bool,
+         default=False,
+         help="Whether to run the model in float16 precision.",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def load_model(args):
+     """Load and return the model and tokenizer"""
+
+     args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5", "T0"]])
+     args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt", "opt", "bloom"]])
+     if args.is_seq2seq_model:
+         model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
+     elif args.is_decoder_only_model:
+         if args.load_fp16:
+             model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.float16,
+                                                          device_map='auto')
+         else:
+             model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
+     else:
+         raise ValueError(f"Unknown model type: {args.model_name_or_path}")
+
+     if args.use_gpu:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         if args.load_fp16:
+             pass  # model was already placed on devices by `device_map='auto'`
+         else:
+             model = model.to(device)
+     else:
+         device = "cpu"
+     model.eval()
+
+     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+
+     return model, tokenizer, device
+
+
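A minimal sketch of driving `load_model` directly; a small checkpoint for a CPU-only smoke test is assumed here ('facebook/opt-125m' is an illustrative choice, not the demo's default):

```python
from argparse import Namespace

# Only the fields load_model reads are needed for this sketch.
args = Namespace(model_name_or_path="facebook/opt-125m", load_fp16=False, use_gpu=False)
model, tokenizer, device = load_model(args)
print(args.is_decoder_only_model, device)  # True cpu
```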
+ from text_generation import InferenceAPIClient
+ from requests.exceptions import ReadTimeout
+
+
+ def generate_with_api(prompt, args):
+     hf_api_key = os.environ.get("HF_API_KEY")
+     if hf_api_key is None:
+         raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
+
+     client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)
+
+     assert args.n_beams == 1, "HF API models do not support beam search."
+     generation_params = {
+         "max_new_tokens": args.max_new_tokens,
+         "do_sample": args.use_sampling,
+     }
+     if args.use_sampling:
+         generation_params["temperature"] = args.sampling_temp
+         generation_params["seed"] = args.generation_seed
+
+     timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
+     try:
+         generation_params["watermark"] = False
+         without_watermark_iterator = client.generate_stream(prompt, **generation_params)
+     except ReadTimeout as e:
+         print(e)
+         without_watermark_iterator = (char for char in timeout_msg)
+     try:
+         generation_params["watermark"] = True
+         with_watermark_iterator = client.generate_stream(prompt, **generation_params)
+     except ReadTimeout as e:
+         print(e)
+         with_watermark_iterator = (char for char in timeout_msg)
+
+     all_without_words, all_with_words = "", ""
+     for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator):
+         all_without_words += without_word.token.text
+         all_with_words += with_word.token.text
+         yield all_without_words, all_with_words
+
+
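`generate_with_api` is a generator: each yield carries the running (unwatermarked, watermarked) string pair, so callers can stream partial text to the UI. A consumption sketch, assuming `args` is configured for one of the `API_MODEL_MAP` models and `HF_API_KEY` is set in the environment:

```python
final_without, final_with = "", ""
for partial_without, partial_with in generate_with_api("The diamondback terrapin is", args):
    # each iteration extends both running strings by one streamed token
    final_without, final_with = partial_without, partial_with
print(final_with)  # final watermarked completion
```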
+ def check_prompt(prompt, args, tokenizer, model=None, device=None):
+     # This applies to both the local and API model scenarios
+     if args.model_name_or_path in API_MODEL_MAP:
+         args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
+     elif hasattr(model.config, "max_position_embeddings"):
+         args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
+     else:
+         args.prompt_max_length = 2048 - args.max_new_tokens
+
+     tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
+                            max_length=args.prompt_max_length).to(device)
+     truncation_warning = tokd_input["input_ids"].shape[-1] == args.prompt_max_length
+     redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
+
+     return (redecoded_input,
+             int(truncation_warning),
+             args)
+
+
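The prompt budget above reserves room for the continuation: the usable prompt length is the model's context size minus `max_new_tokens`. For example:

```python
# With the 2048-token fallback context and the default max_new_tokens=200:
max_context, max_new_tokens = 2048, 200
prompt_max_length = max_context - max_new_tokens
assert prompt_max_length == 1848  # prompt tokens beyond this are truncated away
```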
+ def generate(prompt, args, tokenizer, model=None, device=None):
+     """Instantiate the WatermarkLogitsProcessor according to the watermark parameters
+     and generate watermarked text by passing it to the generate method of the model
+     as a logits processor."""
+
+     print(f"Generating with {args}")
+     print(f"Prompt: {prompt}")
+
+     if args.model_name_or_path in API_MODEL_MAP:
+         api_outputs = generate_with_api(prompt, args)
+         yield from api_outputs
+     else:
+         tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
+                                max_length=args.prompt_max_length).to(device)
+
+         watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
+                                                        gamma=args.gamma,
+                                                        delta=args.delta,
+                                                        seeding_scheme=args.seeding_scheme,
+                                                        select_green_tokens=args.select_green_tokens)
+
+         gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
+
+         if args.use_sampling:
+             gen_kwargs.update(dict(
+                 do_sample=True,
+                 top_k=0,
+                 temperature=args.sampling_temp
+             ))
+         else:
+             gen_kwargs.update(dict(
+                 num_beams=args.n_beams
+             ))
+
+         generate_without_watermark = partial(
+             model.generate,
+             **gen_kwargs
+         )
+         generate_with_watermark = partial(
+             model.generate,
+             logits_processor=LogitsProcessorList([watermark_processor]),
+             **gen_kwargs
+         )
+
+         torch.manual_seed(args.generation_seed)
+         output_without_watermark = generate_without_watermark(**tokd_input)
+
+         # optional to seed before the second generation, but the outputs will generally
+         # not match anyway, unless delta==0.0 (a no-op watermark)
+         if args.seed_separately:
+             torch.manual_seed(args.generation_seed)
+         output_with_watermark = generate_with_watermark(**tokd_input)
+
+         if args.is_decoder_only_model:
+             # need to isolate the newly generated tokens
+             output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
+             output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]
+
+         decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
+         decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
+
+         # mocking the API outputs in a whitespace split generator style
+         all_without_words, all_with_words = "", ""
+         for without_word, with_word in zip(decoded_output_without_watermark.split(),
+                                            decoded_output_with_watermark.split()):
+             all_without_words += without_word + " "
+             all_with_words += with_word + " "
+             yield all_without_words, all_with_words
+
+
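The local branch mocks the API's streaming interface by re-yielding whitespace-split words, so both paths look identical to the UI. A consumption sketch, assuming a loaded model/tokenizer from `load_model` and a prepared `args`:

```python
partial_without, partial_with = "", ""
for partial_without, partial_with in generate("The diamondback terrapin is", args,
                                              tokenizer, model=model, device=device):
    pass  # each step extends both outputs by one whitespace-delimited word
print(partial_with)  # final watermarked completion
```

Note that because the mock stream zips the two word lists together, it stops at the shorter of the two decoded outputs.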
+ def format_names(s):
+     """Format names for the gradio demo interface"""
+     s = s.replace("num_tokens_scored", "Tokens Counted (T)")
+     s = s.replace("num_green_tokens", "# Tokens in Greenlist")
+     s = s.replace("green_fraction", "Fraction of T in Greenlist")
+     s = s.replace("z_score", "z-score")
+     s = s.replace("p_value", "p value")
+     s = s.replace("prediction", "Prediction")
+     s = s.replace("confidence", "Confidence")
+     return s
+
+
+ def list_format_scores(score_dict, detection_threshold):
+     """Format the detection metrics into a gradio dataframe input format"""
+     lst_2d = []
+     for k, v in score_dict.items():
+         if k == 'green_fraction':
+             lst_2d.append([format_names(k), f"{v:.1%}"])
+         elif k == 'confidence':
+             lst_2d.append([format_names(k), f"{v:.3%}"])
+         elif isinstance(v, float):
+             lst_2d.append([format_names(k), f"{v:.3g}"])
+         elif isinstance(v, bool):
+             lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
+         else:
+             lst_2d.append([format_names(k), f"{v}"])
+     if "confidence" in score_dict:
+         lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
+     else:
+         lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
+     return lst_2d
+
+
1171
+ def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True):
1172
+ """Instantiate the WatermarkDetection object and call detect on
1173
+ the input text returning the scores and outcome of the test"""
1174
+
1175
+ print(f"Detecting with {args}")
1176
+ print(f"Detection Tokenizer: {type(tokenizer)}")
1177
+
1178
+ watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
1179
+ gamma=args.gamma,
1180
+ seeding_scheme=args.seeding_scheme,
1181
+ device=device,
1182
+ tokenizer=tokenizer,
1183
+ z_threshold=args.detection_z_threshold,
1184
+ normalizers=args.normalizers,
1185
+ ignore_repeated_bigrams=args.ignore_repeated_bigrams,
1186
+ select_green_tokens=args.select_green_tokens)
1187
+ # for now, just don't display the green token mask
1188
+ # if we're using normalizers or ignore_repeated_bigrams
1189
+ if args.normalizers or args.ignore_repeated_bigrams:
1190
+ return_green_token_mask = False
1191
+
1192
+ error = False
1193
+ green_token_mask = None
1194
+ if input_text == "":
1195
+ error = True
1196
+ else:
1197
+ try:
1198
+ score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
1199
+ green_token_mask = score_dict.pop("green_token_mask", None)
1200
+ output = list_format_scores(score_dict, watermark_detector.z_threshold)
1201
+ except ValueError as e:
1202
+ print(e)
1203
+ error = True
1204
+ if error:
1205
+ output = [["Error", "string too short to compute metrics"]]
1206
+ output += [["", ""] for _ in range(6)]
1207
+
1208
+ html_output = "[No highlight markup generated]"
1209
+ if green_token_mask is not None:
1210
+ # hack: we need a fast tokenizer with charspan support here
1211
+ if "opt" in args.model_name_or_path:
1212
+ tokenizer = OPT_TOKENIZER.from_pretrained(args.model_name_or_path)
1213
+
1214
+ tokens = tokenizer(input_text)
1215
+ if tokens["input_ids"][0] == tokenizer.bos_token_id:
1216
+ tokens["input_ids"] = tokens["input_ids"][1:] # drop the bos token; the stale attention mask is unused below
1217
+ skip = watermark_detector.min_prefix_len
1218
+ charspans = [tokens.token_to_chars(i) for i in range(skip, len(tokens["input_ids"]))]
1219
+ charspans = [cs for cs in charspans if cs is not None] # remove the special token spans
1220
+
1221
+ # the number of charspans must match the number of scored tokens
1222
+ assert len(charspans) == len(green_token_mask)
1223
+
1224
+ tags = [(
1225
+ f'<span class="green">{input_text[cs.start:cs.end]}</span>' if m else f'<span class="red">{input_text[cs.start:cs.end]}</span>')
1226
+ for cs, m in zip(charspans, green_token_mask)]
1227
+ html_output = f'<p>{" ".join(tags)}</p>'
1228
+
1229
+ return output, args, tokenizer, html_output
1230
+
1231
+
1232
+ def run_gradio(args, model=None, device=None, tokenizer=None):
1233
+ """Define and launch the gradio demo interface"""
1234
+ check_prompt_partial = partial(check_prompt, model=model, device=device)
1235
+ generate_partial = partial(generate, model=model, device=device)
1236
+ detect_partial = partial(detect, device=device)
1237
+
1238
+ css = """
1239
+ .green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
1240
+ .red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
1241
+ """
1242
+
1243
+ # with gr.Blocks(theme="xiaobaiyuan/theme_brief") as demo:
1244
+ with gr.Blocks(css=css, theme="xiaobaiyuan/theme_brief") as demo:
1245
+ # Top section, greeting and instructions
1246
+ with gr.Row():
1247
+ with gr.Column(scale=9):
1248
+ gr.Markdown(
1249
+ """
1250
+ # 💧 Large Language Model Watermarking 🔍
1251
+ """
1252
+ )
1253
+ with gr.Column(scale=1):
1254
+ # if the model_name_or_path given at startup is not one of the API models, add it to the dropdown
1255
+ all_models = sorted(list(set(list(API_MODEL_MAP.keys()) + [args.model_name_or_path])))
1256
+ model_selector = gr.Dropdown(
1257
+ all_models,
1258
+ value=args.model_name_or_path,
1259
+ label="Language Model",
1260
+ )
1261
+
1262
+ with gr.Accordion("Parameter descriptions", open=False):
1263
+ gr.Markdown(
1264
+ """
1265
+ - `z-score threshold`: the cutoff value for the hypothesis test.
1266
+ - `Tokens Counted (T)`: the number of tokens in the output counted by the detection algorithm.
1267
+ Under the simple, single token seeding scheme, the first token is omitted because there is no prefix token with which to generate a greenlist for it.
1268
+ Under the "Ignore Repeated Bigrams" detection algorithm described in the bottom panel, this count can be much smaller than the total number of generated tokens if there is a lot of repetition.
1269
+ - `# Tokens in Greenlist`: the number of observed tokens that fall in their respective greenlists.
1270
+ - `Fraction of T in Greenlist`: `# Tokens in Greenlist` / `T`. This is expected to be approximately gamma for human/unwatermarked text.
1271
+ - `z-score`: the test statistic of the detection hypothesis test, computed as (c - gamma*T)/sqrt(gamma*(1-gamma)*T) for c green tokens out of T. If it is larger than the `z-score threshold`, we "reject the null hypothesis" that the text is human/unwatermarked and infer that it is watermarked.
1272
+ - `p value`: the probability of observing the computed `z-score` under the null hypothesis.
1273
+ This is the probability of observing this `Fraction of T in Greenlist` without knowledge of the watermarking procedure/greenlists.
1274
+ If this value is extremely small, we are confident that this many green tokens were not chosen by random chance.
1275
+ - `Prediction`: the outcome of the hypothesis test, i.e. whether the observed `z-score` is higher than the `z-score threshold`.
1276
+ - `Confidence`: if we reject the null hypothesis and the `Prediction` is "Watermarked", we report 1-`p value` as the confidence of the detection based on the unlikeliness of this `z-score` observation.
1277
+ """
1278
+ )
1279
+
1280
+ with gr.Accordion("A note on model capability", open=True):
1281
+ gr.Markdown(
1282
+ """
1283
+ This demo uses open-source language models that fit on a single GPU. These models are less capable than proprietary commercial tools like ChatGPT, Claude, or Bard.
1284
+
1285
+ One more thing: we use language models that are intended to "complete" your prompt, rather than models fine-tuned to follow instructions.
1286
+ For best results, prompt the model with a few sentences that form the beginning of a paragraph, and then let it "continue" your paragraph.
1287
+ Some examples include the opening paragraph of a Wikipedia article, or the first few sentences of a story.
1288
+ Longer prompts that break off abruptly at the end will lead to more fluent generations.
1289
+ """
1290
+ )
1291
+
1292
+ # Construct state for parameters, define updates and toggles
1293
+ default_prompt = args.__dict__.pop("default_prompt")
1294
+ session_args = gr.State(value=args)
1295
+ # note that the State object automatically calls value if it's a callable; we want to avoid loading the tokenizer at startup
1296
+ session_tokenizer = gr.State(value=lambda: tokenizer)
1297
+ # with gr.Row():
1298
+ # gr.Markdown(
1299
+ # """
1300
+ # Friendly reminder: if an ERROR appears, the API may not have finished loading yet; wait a moment and retry
1301
+ # """
1302
+ # )
1303
+ with gr.Tab("Generate and Detect"):
1304
+ with gr.Row():
1305
+ prompt = gr.Textbox(label="Prompt", interactive=True, lines=10, max_lines=10, value=default_prompt)
1306
+ with gr.Row():
1307
+ generate_btn = gr.Button("Generate")
1308
+ with gr.Row():
1309
+ with gr.Column(scale=2):
1310
+ with gr.Tab("Output Without Watermark"):
1311
+ output_without_watermark = gr.Textbox(label=None, interactive=False, lines=14,
1312
+ max_lines=14, show_label=False)
1313
+ with gr.Tab("Highlights"):
1314
+ html_without_watermark = gr.HTML(elem_id="html-without-watermark")
1315
+ with gr.Column(scale=1):
1316
+ # without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
1317
+ without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
1318
+ row_count=7, col_count=2)
1319
+
1320
+ with gr.Row():
1321
+ with gr.Column(scale=2):
1322
+ with gr.Tab("Output With Watermark"):
1323
+ output_with_watermark = gr.Textbox(label=None, interactive=False, lines=14,
1324
+ max_lines=14, show_label=False)
1325
+ with gr.Tab("Highlights"):
1326
+ html_with_watermark = gr.HTML(elem_id="html-with-watermark")
1327
+ with gr.Column(scale=1):
1328
+ # with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
1329
+ with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
1330
+ row_count=7, col_count=2)
1331
+
1332
+
1333
+ redecoded_input = gr.Textbox(visible=False)
1334
+ truncation_warning = gr.Number(visible=False)
1335
+ def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
1336
+ if truncation_warning:
1337
+ return redecoded_input + "\n\n[Prompt was truncated before generation due to length...]", args
1338
+ else:
1339
+ return orig_prompt, args
1340
+
1341
+ with gr.Tab("Detector Only"):
1342
+ with gr.Row():
1343
+ with gr.Column(scale=2):
1344
+ with gr.Tab("Text to Analyze"):
1345
+ detection_input = gr.Textbox(interactive=True, lines=14, max_lines=14, show_label=False)
1346
+ with gr.Tab("Highlights"):
1347
+ html_detection_input = gr.HTML(elem_id="html-detection-input")
1348
+ with gr.Column(scale=1):
1349
+ detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7,
1350
+ col_count=2)
1351
+ with gr.Row():
1352
+ detect_btn = gr.Button("Detect")
1353
+
1354
+ # Parameter selection group
1355
+ with gr.Accordion("Advanced Settings", open=False):
1356
+ with gr.Row():
1357
+ with gr.Column(scale=1):
1358
+ gr.Markdown("#### Generation Parameters")
1359
+ with gr.Row():
1360
+ decoding = gr.Radio(label="Decoding Method", choices=["multinomial", "greedy"],
1361
+ value=("multinomial" if args.use_sampling else "greedy"))
1362
+
1363
+ with gr.Row():
1364
+ sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1,
1365
+ value=args.sampling_temp, visible=True)
1366
+ with gr.Row():
1367
+ generation_seed = gr.Number(label="Generation Seed", value=args.generation_seed, interactive=True)
1368
+ with gr.Row():
1369
+ n_beams = gr.Dropdown(label="Number of Beams", choices=list(range(1, 11, 1)), value=args.n_beams,
1370
+ visible=(not args.use_sampling))
1371
+ with gr.Row():
1372
+ max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10,
1373
+ value=args.max_new_tokens)
1374
+
1375
+ with gr.Column(scale=1):
1376
+ gr.Markdown("#### Watermark Parameters")
1377
+ with gr.Row():
1378
+ gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
1379
+ with gr.Row():
1380
+ delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
1381
+ gr.Markdown("#### Detector Parameters")
1382
+ with gr.Row():
1383
+ detection_z_threshold = gr.Slider(label="z-score threshold", minimum=0.0, maximum=10.0, step=0.1,
1384
+ value=args.detection_z_threshold)
1385
+ with gr.Row():
1386
+ ignore_repeated_bigrams = gr.Checkbox(label="Ignore Repeated Bigrams")
1387
+ with gr.Row():
1388
+ normalizers = gr.CheckboxGroup(label="Normalizers",
1389
+ choices=["unicode", "homoglyphs", "truecase"],
1390
+ value=args.normalizers)
1391
+ # with gr.Accordion("Actual submitted parameters:",open=False):
1392
+ with gr.Row():
1393
+ gr.Markdown(
1394
+ "_Note: sliders can lag when updating; clicking on the slider bar or using the number window to the right can help. The window below shows the current settings._")
1395
+ with gr.Row():
1396
+ current_parameters = gr.Textbox(label="Current Parameters", value=args, interactive=False, lines=6)
1397
+ with gr.Accordion("Legacy Settings", open=False):
1398
+ with gr.Row():
1399
+ with gr.Column(scale=1):
1400
+ seed_separately = gr.Checkbox(label="Seed both generations separately", value=args.seed_separately)
1401
+ with gr.Column(scale=1):
1402
+ select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition",
1403
+ value=args.select_green_tokens)
1404
+
1405
+ with gr.Accordion("About the settings", open=False):
1406
+ gr.Markdown(
1407
+ """
1408
+ #### Generation Parameters:
1409
+
1410
+ - Decoding Method: we can use multinomial sampling or greedy decoding to generate tokens from the model.
1411
+ - Sampling Temperature: if using multinomial sampling, we can set the temperature of the sampling distribution.
1412
+ 0.0 is equivalent to greedy decoding, and 1.0 is maximum variability/entropy in the next token distribution.
1413
+ 0.7 adds diversity while preserving the accuracy of the model's estimates of its top candidates. Has no effect on greedy decoding.
1414
+ - Generation Seed: the integer passed to the torch random number generator before running generation. Makes the multinomial sampling strategy outputs reproducible. Has no effect on greedy decoding.
1415
+ - Number of Beams: when using greedy decoding, the number of beams can also be set to > 1 to enable beam search.
1416
+ This is not implemented/excluded from the paper for multinomial sampling, but may be added in the future.
1417
+ - Max Generated Tokens: the `max_new_tokens` parameter passed to the generate method, to stop the output at a certain number of new tokens.
1418
+ Note that the model may generate fewer tokens depending on the prompt.
1419
+ This implicitly sets the number of possible prompt tokens to the model's maximum input length minus `max_new_tokens`,
1420
+ and the input will be truncated accordingly.
1421
+
1422
+ #### Watermark Parameters:
1423
+
1424
+ - gamma: the fraction of the vocabulary partitioned into the greenlist at each generation step. Smaller gamma values create a stronger watermark
1425
+ by making the watermarked model better distinguishable from human/unwatermarked text, since preferentially sampling from a smaller green set makes those tokens less likely to occur by chance.
1426
+ - delta: the amount of positive bias added to the log-probability of every greenlist token at each generation step, before the next token is sampled/selected.
1427
+ Higher delta values mean the greenlist tokens are more strongly favored by the watermarked model, and as the bias grows the watermark transitions from "soft" to "hard".
1428
+ For a hard watermark nearly all tokens are green, but this can have a detrimental effect on generation quality, especially when there is limited flexibility in the distribution.
1429
+
1430
+ #### Detector Parameters:
1431
+
1432
+ - z-score threshold: the z-score cutoff for the hypothesis test. Higher thresholds (such as 4.0) make it extremely unlikely that human/unwatermarked text is predicted to be watermarked
1433
+ (_false positives_), since a genuine human text with many tokens will almost never achieve such a high z-score.
1434
+ Lower thresholds will catch more texts that are truly watermarked, since some watermarked texts may contain fewer green tokens, achieve a lower z-score,
1435
+ and yet still clear the lower bar and be flagged as "watermarked". However, a lower threshold raises the chance that human text with a slightly above-average number of green tokens is wrongly flagged as watermarked.
1436
+ 4.0-5.0 offers an extremely low false positive rate while still accurately catching most watermarked text.
1437
+ - Ignore Repeated Bigrams: this alternate detection algorithm considers only the unique bigrams in the text during detection,
1438
+ computing the greenlist from the first token in each pair and checking whether the second one lies within it.
1439
+ This means that `T` is now the number of unique bigrams in the text,
1440
+ which will be less than the total number of generated tokens if the text contains a lot of repetition.
1441
+ See the paper for a more detailed discussion.
1442
+ - Normalizations: we implement a few basic normalizations to defend against various adversarial perturbations of the text during detection.
1443
+ Currently we support converting all characters to Unicode, replacing homoglyphs with canonical forms, and standardizing the capitalization.
1444
+ """
1445
+ )
1446
+
1447
+
1448
+ # Register the main generation tab click: it outputs the generations as well as the encoded+redecoded (and potentially truncated) prompt and a truncation flag, then calls detection
1449
+ generate_btn.click(fn=check_prompt_partial, inputs=[prompt, session_args, session_tokenizer],
1450
+ outputs=[redecoded_input, truncation_warning, session_args]).success(
1451
+ fn=generate_partial, inputs=[redecoded_input, session_args, session_tokenizer],
1452
+ outputs=[output_without_watermark, output_with_watermark]).success(
1453
+ fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
1454
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
1455
+ html_without_watermark]).success(
1456
+ fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
1457
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
1458
+ # Show truncated version of prompt if truncation occurred
1459
+ redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
1460
+ outputs=[prompt, session_args])
1461
+ # Register main detection tab click
1462
+ detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1463
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input],
1464
+ api_name="detection")
1465
+
1466
+ # State management logic
1467
+ # define update callbacks that change the state dict
1468
+ def update_model(session_state, value):
1469
+ session_state.model_name_or_path = value; return session_state
1470
+
1471
+ def update_sampling_temp(session_state, value):
1472
+ session_state.sampling_temp = float(value); return session_state
1473
+
1474
+ def update_generation_seed(session_state, value):
1475
+ session_state.generation_seed = int(value); return session_state
1476
+
1477
+ def update_gamma(session_state, value):
1478
+ session_state.gamma = float(value); return session_state
1479
+
1480
+ def update_delta(session_state, value):
1481
+ session_state.delta = float(value); return session_state
1482
+
1483
+ def update_detection_z_threshold(session_state, value):
1484
+ session_state.detection_z_threshold = float(value); return session_state
1485
+
1486
+ def update_decoding(session_state, value):
1487
+ if value == "multinomial":
1488
+ session_state.use_sampling = True
1489
+ elif value == "greedy":
1490
+ session_state.use_sampling = False
1491
+ return session_state
1492
+
1493
+ def toggle_sampling_vis(value):
1494
+ if value == "multinomial":
1495
+ return gr.update(visible=True)
1496
+ elif value == "greedy":
1497
+ return gr.update(visible=False)
1498
+
1499
+ def toggle_sampling_vis_inv(value):
1500
+ if value == "multinomial":
1501
+ return gr.update(visible=False)
1502
+ elif value == "greedy":
1503
+ return gr.update(visible=True)
1504
+
1505
+ # if model name is in the list of api models, set the num beams parameter to 1 and hide n_beams
1506
+ def toggle_vis_for_api_model(value):
1507
+ if value in API_MODEL_MAP:
1508
+ return gr.update(visible=False)
1509
+ else:
1510
+ return gr.update(visible=True)
1511
+
1512
+ def toggle_beams_for_api_model(value, orig_n_beams):
1513
+ if value in API_MODEL_MAP:
1514
+ return gr.update(value=1)
1515
+ else:
1516
+ return gr.update(value=orig_n_beams)
1517
+
1518
+ # if model name is in the list of api models, set the interactive parameter to false
1519
+ def toggle_interactive_for_api_model(value):
1520
+ if value in API_MODEL_MAP:
1521
+ return gr.update(interactive=False)
1522
+ else:
1523
+ return gr.update(interactive=True)
1524
+
1525
+ # if model name is in the list of api models, set gamma and delta based on API map
1526
+ def toggle_gamma_for_api_model(value, orig_gamma):
1527
+ if value in API_MODEL_MAP:
1528
+ return gr.update(value=API_MODEL_MAP[value]["gamma"])
1529
+ else:
1530
+ return gr.update(value=orig_gamma)
1531
+
1532
+ def toggle_delta_for_api_model(value, orig_delta):
1533
+ if value in API_MODEL_MAP:
1534
+ return gr.update(value=API_MODEL_MAP[value]["delta"])
1535
+ else:
1536
+ return gr.update(value=orig_delta)
1537
+
1538
+ def update_n_beams(session_state, value):
1539
+ session_state.n_beams = int(value); return session_state
1540
+
1541
+ def update_max_new_tokens(session_state, value):
1542
+ session_state.max_new_tokens = int(value); return session_state
1543
+
1544
+ def update_ignore_repeated_bigrams(session_state, value):
1545
+ session_state.ignore_repeated_bigrams = value; return session_state
1546
+
1547
+ def update_normalizers(session_state, value):
1548
+ session_state.normalizers = value; return session_state
1549
+
1550
+ def update_seed_separately(session_state, value):
1551
+ session_state.seed_separately = value; return session_state
1552
+
1553
+ def update_select_green_tokens(session_state, value):
1554
+ session_state.select_green_tokens = value; return session_state
1555
+
1556
+ def update_tokenizer(model_name_or_path):
1557
+ # if model_name_or_path == ALPACA_MODEL_NAME:
1558
+ # return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
1559
+ # else:
1560
+ return AutoTokenizer.from_pretrained(model_name_or_path)
1561
+
1562
+ def check_model(value):
1563
+ return value if (value != "" and value is not None) else args.model_name_or_path
1564
+
1565
+ # enforce constraint that model cannot be null or empty
1566
+ # then attach model callbacks in particular
1567
+ model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
1568
+ toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams]
1569
+ ).then(
1570
+ toggle_beams_for_api_model, inputs=[model_selector, n_beams], outputs=[n_beams]
1571
+ ).then(
1572
+ toggle_interactive_for_api_model, inputs=[model_selector], outputs=[gamma]
1573
+ ).then(
1574
+ toggle_interactive_for_api_model, inputs=[model_selector], outputs=[delta]
1575
+ ).then(
1576
+ toggle_gamma_for_api_model, inputs=[model_selector, gamma], outputs=[gamma]
1577
+ ).then(
1578
+ toggle_delta_for_api_model, inputs=[model_selector, delta], outputs=[delta]
1579
+ ).then(
1580
+ update_tokenizer, inputs=[model_selector], outputs=[session_tokenizer]
1581
+ ).then(
1582
+ update_model, inputs=[session_args, model_selector], outputs=[session_args]
1583
+ ).then(
1584
+ lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
1585
+ )
1586
+ # registering callbacks for toggling the visibility of certain parameters based on the values of others
1587
+ decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
1588
+ decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
1589
+ decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
1590
+ decoding.change(toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams])
1591
+ # registering all state update callbacks
1592
+ decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
1593
+ sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
1594
+ generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
1595
+ n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
1596
+ max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
1597
+ gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
1598
+ delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
1599
+ detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold],
1600
+ outputs=[session_args])
1601
+ ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams],
1602
+ outputs=[session_args])
1603
+ normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
1604
+ seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
1605
+ select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens],
1606
+ outputs=[session_args])
1607
+ # register additional callback on button clicks that updates the shown parameters window
1608
+ generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1609
+ detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1610
+ # When the parameters change, display the update and also re-fire detection, since detection params can be changed without regenerating the model output.
1611
+ delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1612
+ gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1613
+ gamma.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
1614
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
1615
+ html_without_watermark])
1616
+ gamma.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
1617
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
1618
+ gamma.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1619
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
1620
+ detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1621
+ detection_z_threshold.change(fn=detect_partial,
1622
+ inputs=[output_without_watermark, session_args, session_tokenizer],
1623
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
1624
+ html_without_watermark])
1625
+ detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
1626
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer,
1627
+ html_with_watermark])
1628
+ detection_z_threshold.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1629
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
1630
+ ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1631
+ ignore_repeated_bigrams.change(fn=detect_partial,
1632
+ inputs=[output_without_watermark, session_args, session_tokenizer],
1633
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
1634
+ html_without_watermark])
1635
+ ignore_repeated_bigrams.change(fn=detect_partial,
1636
+ inputs=[output_with_watermark, session_args, session_tokenizer],
1637
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer,
1638
+ html_with_watermark])
1639
+ ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1640
+ outputs=[detection_result, session_args, session_tokenizer,
1641
+ html_detection_input])
1642
+ normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1643
+ normalizers.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
1644
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
1645
+ html_without_watermark])
1646
+ normalizers.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
1647
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer,
1648
+ html_with_watermark])
1649
+ normalizers.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1650
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
1651
+ select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1652
+ select_green_tokens.change(fn=detect_partial,
1653
+ inputs=[output_without_watermark, session_args, session_tokenizer],
1654
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
1655
+ html_without_watermark])
1656
+ select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
1657
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer,
1658
+ html_with_watermark])
1659
+ select_green_tokens.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1660
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
1661
+
1662
+ demo.queue(concurrency_count=3)
1663
+
1664
+ if args.demo_public:
1665
+ demo.launch(share=True) # exposes app to the internet via randomly generated link
1666
+ else:
1667
+ demo.launch()
1668
+
1669
+
1670
+ def main(args):
1671
+ """Run a command line version of the generation and detection operations
1672
+ and optionally launch and serve the gradio demo"""
1673
+ # Initial arg processing and log
1674
+ args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
1675
+ print(args)
1676
+
1677
+ if not args.skip_model_load:
1678
+ model, tokenizer, device = load_model(args)
1679
+ else:
1680
+ model, tokenizer, device = None, None, None
1681
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
1682
+ if args.use_gpu:
1683
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1684
+ else:
1685
+ device = "cpu"
1686
+
1687
+ # terrapin example
1688
+ input_text = (
1689
+ "The diamondback terrapin or simply terrapin (Malaclemys terrapin) is a "
1690
+ "species of turtle native to the brackish coastal tidal marshes of the "
1691
+ "Northeastern and southern United States, and in Bermuda.[6] It belongs "
1692
+ "to the monotypic genus Malaclemys. It has one of the largest ranges of "
1693
+ "all turtles in North America, stretching as far south as the Florida Keys "
1694
+ "and as far north as Cape Cod.[7] The name 'terrapin' is derived from the "
1695
+ "Algonquian word torope.[8] It applies to Malaclemys terrapin in both "
1696
+ "British English and American English. The name originally was used by "
1697
+ "early European settlers in North America to describe these brackish-water "
1698
+ "turtles that inhabited neither freshwater habitats nor the sea. It retains "
1699
+ "this primary meaning in American English.[8] In British English, however, "
1700
+ "other semi-aquatic turtle species, such as the red-eared slider, might "
1701
+ "also be called terrapins. The common name refers to the diamond pattern "
1702
+ "on top of its shell (carapace), but the overall pattern and coloration "
1703
+ "vary greatly. The shell is usually wider at the back than in the front, "
1704
+ "and from above it appears wedge-shaped. The shell coloring can vary "
1705
+ "from brown to grey, and its body color can be grey, brown, yellow, "
1706
+ "or white. All have a unique pattern of wiggly, black markings or spots "
1707
+ "on their body and head. The diamondback terrapin has large webbed "
1708
+ "feet.[9] The species is"
1709
+ )
1710
+
1711
+ args.default_prompt = input_text
1712
+
1713
+ # Generate and detect, report to stdout
1714
+ if not args.skip_model_load:
1715
+
1716
+ term_width = 80
1717
+ print("#" * term_width)
1718
+ print("Prompt:")
1719
+ print(input_text)
1720
+
1721
+ # a generator that yields (without_watermark, with_watermark) pairs
1722
+ generator_outputs = generate(input_text,
1723
+ args,
1724
+ model=model,
1725
+ device=device,
1726
+ tokenizer=tokenizer)
1727
+ # we need to iterate over it,
1728
+ # but we only want the last output in this case
1729
+ for out in generator_outputs:
1730
+ decoded_output_without_watermark = out[0]
1731
+ decoded_output_with_watermark = out[1]
1732
+
1733
+ without_watermark_detection_result = detect(decoded_output_without_watermark,
1734
+ args,
1735
+ device=device,
1736
+ tokenizer=tokenizer,
1737
+ return_green_token_mask=False)
1738
+ with_watermark_detection_result = detect(decoded_output_with_watermark,
1739
+ args,
1740
+ device=device,
1741
+ tokenizer=tokenizer,
1742
+ return_green_token_mask=False)
1743
+
1744
+ print("#" * term_width)
1745
+ print("Output without watermark:")
1746
+ print(decoded_output_without_watermark)
1747
+ print("-" * term_width)
1748
+ print(f"Detection result @ {args.detection_z_threshold}:")
1749
+ pprint(without_watermark_detection_result)
1750
+ print("-" * term_width)
1751
+
1752
+ print("#" * term_width)
1753
+ print("Output with watermark:")
1754
+ print(decoded_output_with_watermark)
1755
+ print("-" * term_width)
1756
+ print(f"Detection result @ {args.detection_z_threshold}:")
1757
+ pprint(with_watermark_detection_result)
1758
+ print("-" * term_width)
1759
+
1760
+ # Launch the app to generate and detect interactively (implements the hf space demo)
1761
+ if args.run_gradio:
1762
+ run_gradio(args, model=model, tokenizer=tokenizer, device=device)
1763
+
1764
+ return
1765
+
1766
+
1767
+ if __name__ == "__main__":
1768
+ args = parse_args()
1769
+ print(args)
1770
+
1771
+ main(args)
1772
+
homoglyphs.py ADDED
@@ -0,0 +1,265 @@
1
+ """Updated version of core.py from
2
+ https://github.com/yamatt/homoglyphs/tree/main/homoglyphs_fork
3
+ for modern python3
4
+ """
5
+
6
+ from collections import defaultdict
7
+ import json
8
+ from itertools import product
9
+ import os
10
+ import unicodedata
11
+
12
+ # Actions if char not in alphabet
13
+ STRATEGY_LOAD = 1 # load category for this char
14
+ STRATEGY_IGNORE = 2 # add char to result
15
+ STRATEGY_REMOVE = 3 # remove char from result
16
+
17
+ ASCII_RANGE = range(128)
18
+
19
+
20
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
21
+ DATA_LOCATION = os.path.join(CURRENT_DIR, "homoglyph_data")
22
+
23
+
24
+ class Categories:
25
+ """
26
+ Work with aliases from ISO 15924.
27
+ https://en.wikipedia.org/wiki/ISO_15924#List_of_codes
28
+ """
29
+
30
+ fpath = os.path.join(DATA_LOCATION, "categories.json")
31
+
32
+ @classmethod
33
+ def _get_ranges(cls, categories):
34
+ """
35
+ :return: iter: (start code, end code)
36
+ :rtype: list
37
+ """
38
+ with open(cls.fpath, encoding="utf-8") as f:
39
+ data = json.load(f)
40
+
41
+ for category in categories:
42
+ if category not in data["aliases"]:
43
+ raise ValueError("Invalid category: {}".format(category))
44
+
45
+ for point in data["points"]:
46
+ if point[2] in categories:
47
+ yield point[:2]
48
+
49
+ @classmethod
50
+ def get_alphabet(cls, categories):
51
+ """
52
+ :return: set of chars in alphabet by categories list
53
+ :rtype: set
54
+ """
55
+ alphabet = set()
56
+ for start, end in cls._get_ranges(categories):
57
+ chars = (chr(code) for code in range(start, end + 1))
58
+ alphabet.update(chars)
59
+ return alphabet
60
+
61
+ @classmethod
62
+ def detect(cls, char):
63
+ """
64
+ :return: category
65
+ :rtype: str
66
+ """
67
+ with open(cls.fpath, encoding="utf-8") as f:
68
+ data = json.load(f)
69
+
70
+ # try to detect the category via unicodedata
71
+ try:
72
+ category = unicodedata.name(char).split()[0]
73
+ except (TypeError, ValueError):
74
+ # in Python 2, unicodedata.name raised an error for non-unicode chars;
74
+ # Python 3 raises ValueError for characters that have no name
76
+ pass
77
+ else:
78
+ if category in data["aliases"]:
79
+ return category
80
+
81
+ # try to detect the category via the ranges from the JSON file
82
+ code = ord(char)
83
+ for point in data["points"]:
84
+ if point[0] <= code <= point[1]:
85
+ return point[2]
86
+
87
+ @classmethod
88
+ def get_all(cls):
89
+ with open(cls.fpath, encoding="utf-8") as f:
90
+ data = json.load(f)
91
+ return set(data["aliases"])
92
+
93
+
94
+ class Languages:
95
+ fpath = os.path.join(DATA_LOCATION, "languages.json")
96
+
97
+ @classmethod
98
+ def get_alphabet(cls, languages):
99
+ """
100
+ :return: set of chars in alphabet by languages list
101
+ :rtype: set
102
+ """
103
+ with open(cls.fpath, encoding="utf-8") as f:
104
+ data = json.load(f)
105
+ alphabet = set()
106
+ for lang in languages:
107
+ if lang not in data:
108
+ raise ValueError("Invalid language code: {}".format(lang))
109
+ alphabet.update(data[lang])
110
+ return alphabet
111
+
112
+ @classmethod
113
+ def detect(cls, char):
114
+ """
115
+ :return: set of languages which alphabet contains passed char.
116
+ :rtype: set
117
+ """
118
+ with open(cls.fpath, encoding="utf-8") as f:
119
+ data = json.load(f)
120
+ languages = set()
121
+ for lang, alphabet in data.items():
122
+ if char in alphabet:
123
+ languages.add(lang)
124
+ return languages
125
+
126
+ @classmethod
127
+ def get_all(cls):
128
+ with open(cls.fpath, encoding="utf-8") as f:
129
+ data = json.load(f)
130
+ return set(data.keys())
131
+
132
+
133
+ class Homoglyphs:
134
+ def __init__(
135
+ self,
136
+ categories=None,
137
+ languages=None,
138
+ alphabet=None,
139
+ strategy=STRATEGY_IGNORE,
140
+ ascii_strategy=STRATEGY_IGNORE,
141
+ ascii_range=ASCII_RANGE,
142
+ ):
143
+ # strategies
144
+ if strategy not in (STRATEGY_LOAD, STRATEGY_IGNORE, STRATEGY_REMOVE):
145
+ raise ValueError("Invalid strategy")
146
+ self.strategy = strategy
147
+ self.ascii_strategy = ascii_strategy
148
+ self.ascii_range = ascii_range
149
+
150
+ # Homoglyphs must be initialized with some alphabet to work correctly
151
+ if not categories and not languages and not alphabet:
152
+ categories = ("LATIN", "COMMON")
153
+
154
+ # cats and langs
155
+ self.categories = set(categories or [])
156
+ self.languages = set(languages or [])
157
+
158
+ # alphabet
159
+ self.alphabet = set(alphabet or [])
160
+ if self.categories:
161
+ alphabet = Categories.get_alphabet(self.categories)
162
+ self.alphabet.update(alphabet)
163
+ if self.languages:
164
+ alphabet = Languages.get_alphabet(self.languages)
165
+ self.alphabet.update(alphabet)
166
+ self.table = self.get_table(self.alphabet)
167
+
168
+ @staticmethod
169
+ def get_table(alphabet):
170
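+ # build a char -> {homoglyphs} map from the Unicode confusables data,
+ # keeping only confusables that are themselves in the given alphabet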
+ table = defaultdict(set)
171
+ with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
172
+ data = json.load(f)
173
+ for char in alphabet:
174
+ if char in data:
175
+ for homoglyph in data[char]:
176
+ if homoglyph in alphabet:
177
+ table[char].add(homoglyph)
178
+ return table
179
+
180
+ @staticmethod
181
+ def get_restricted_table(source_alphabet, target_alphabet):
182
+ table = defaultdict(set)
183
+ with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
184
+ data = json.load(f)
185
+ for char in source_alphabet:
186
+ if char in data:
187
+ for homoglyph in data[char]:
188
+ if homoglyph in target_alphabet:
189
+ table[char].add(homoglyph)
190
+ return table
191
+
192
+ @staticmethod
193
+ def uniq_and_sort(data):
194
+ result = list(set(data))
195
+ result.sort(key=lambda x: (-len(x), x))
196
+ return result
197
+
198
+ def _update_alphabet(self, char):
199
+ # try detect languages
200
+ langs = Languages.detect(char)
201
+ if langs:
202
+ self.languages.update(langs)
203
+ alphabet = Languages.get_alphabet(langs)
204
+ self.alphabet.update(alphabet)
205
+ else:
206
+ # try detect categories
207
+ category = Categories.detect(char)
208
+ if category is None:
209
+ return False
210
+ self.categories.add(category)
211
+ alphabet = Categories.get_alphabet([category])
212
+ self.alphabet.update(alphabet)
213
+ # update table for new alphabet
214
+ self.table = self.get_table(self.alphabet)
215
+ return True
216
+
217
+ def _get_char_variants(self, char):
218
+ if char not in self.alphabet:
219
+ if self.strategy == STRATEGY_LOAD:
220
+ if not self._update_alphabet(char):
221
+ return []
222
+ elif self.strategy == STRATEGY_IGNORE:
223
+ return [char]
224
+ elif self.strategy == STRATEGY_REMOVE:
225
+ return []
226
+
227
+ # find alternative chars for current char
228
+ alt_chars = self.table.get(char, set())
229
+ if alt_chars:
230
+ # find alternatives of those alternative chars as well (one level of transitive closure)
231
+ alt_chars2 = [self.table.get(alt_char, set()) for alt_char in alt_chars]
232
+ # combine all alternatives
233
+ alt_chars.update(*alt_chars2)
234
+ # add current char to alternatives
235
+ alt_chars.add(char)
236
+
237
+ # uniq, sort and return
238
+ return self.uniq_and_sort(alt_chars)
239
+
240
+ def _get_combinations(self, text, ascii=False):
241
+ variations = []
242
+ for char in text:
243
+ alt_chars = self._get_char_variants(char)
244
+
245
+ if ascii:
246
+ alt_chars = [char for char in alt_chars if ord(char) in self.ascii_range]
247
+ if not alt_chars and self.ascii_strategy == STRATEGY_IGNORE:
248
+ return
249
+
250
+ if alt_chars:
251
+ variations.append(alt_chars)
252
+ if variations:
253
+ for variant in product(*variations):
254
+ yield "".join(variant)
255
+
256
+ def get_combinations(self, text):
257
+ return list(self._get_combinations(text))
258
+
259
+ def _to_ascii(self, text):
260
+ for variant in self._get_combinations(text, ascii=True):
261
+ if max(map(ord, variant)) in self.ascii_range:
262
+ yield variant
263
+
264
+ def to_ascii(self, text):
265
+ return self.uniq_and_sort(self._to_ascii(text))
normalizers.py ADDED
@@ -0,0 +1,195 @@
1
+ """ Text-based normalizers, used to mitigate simple attacks against watermarking.
2
+
3
+ This implementation is unlikely to be a complete list of all possible exploits within the unicode standard;
4
+ it represents our best effort at the time of writing.
5
+
6
+ These normalizers can be used as stand-alone normalizers. They could be made to conform to HF tokenizers standard, but that would
7
+ require messing with the limited rust interface of tokenizers.NormalizedString
8
+ """
9
+ from collections import defaultdict
10
+ from functools import cache
11
+
12
+ import re
13
+ import unicodedata
14
+ import homoglyphs as hg
15
+
16
+
17
+ def normalization_strategy_lookup(strategy_name: str) -> object:
18
+ if strategy_name == "unicode":
19
+ return UnicodeSanitizer()
20
+ elif strategy_name == "homoglyphs":
21
+ return HomoglyphCanonizer()
22
+ elif strategy_name == "truecase":
23
+ return TrueCaser()
+ else:
+ raise ValueError(f"Unknown normalization strategy: {strategy_name}")
24
+
25
+
26
+ class HomoglyphCanonizer:
27
+ """Attempts to detect homoglyph attacks and find a consistent canon.
28
+
29
+ This function does so on a per-ISO-category level. Language-level would also be possible (see commented code).
30
+ """
31
+
32
+ def __init__(self):
33
+ self.homoglyphs = None
34
+
35
+ def __call__(self, homoglyphed_str: str) -> str:
36
+ # find canon:
37
+ target_category, all_categories = self._categorize_text(homoglyphed_str)
38
+ homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
39
+ return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)
40
+
41
+ def _categorize_text(self, text: str) -> dict:
42
+ iso_categories = defaultdict(int)
43
+ # self.iso_languages = defaultdict(int)
44
+
45
+ for char in text:
46
+ iso_categories[hg.Categories.detect(char)] += 1
47
+ # for lang in hg.Languages.detect(char):
48
+ # self.iso_languages[lang] += 1
49
+ target_category = max(iso_categories, key=iso_categories.get)
50
+ all_categories = tuple(iso_categories)
51
+ return target_category, all_categories
52
+
53
+ @cache
54
+ def _select_canon_category_and_load(self, target_category: str, all_categories: tuple[str]) -> dict:
55
+ homoglyph_table = hg.Homoglyphs(categories=(target_category, "COMMON")) # alphabet loaded here from file
56
+
57
+ source_alphabet = hg.Categories.get_alphabet(all_categories)
58
+ restricted_table = homoglyph_table.get_restricted_table(source_alphabet, homoglyph_table.alphabet) # table loaded here from file
59
+ return restricted_table
60
+
61
+ def _sanitize_text(self, target_category: str, homoglyph_table: dict, homoglyphed_str: str) -> str:
62
+ sanitized_text = ""
63
+ for char in homoglyphed_str:
64
+ # langs = hg.Languages.detect(char)
65
+ cat = hg.Categories.detect(char)
66
+ if target_category in cat or "COMMON" in cat or len(cat) == 0:
67
+ sanitized_text += char
68
+ else:
69
+ sanitized_text += list(homoglyph_table[char])[0]
70
+ return sanitized_text
71
+
72
+
73
+ class UnicodeSanitizer:
74
+ """Regex-based unicode sanitizer. Has different levels of granularity.
75
+
76
+ * ruleset="whitespaces" - attempts to remove only whitespace unicode characters
77
+ * ruleset="IDN.blacklist" - does its best to remove unusual unicode based on Network.IDN.blacklist characters
78
+ * ruleset="ascii" - brute-forces all text into ascii
79
+
80
+ This is unlikely to be a comprehensive list.
81
+
82
+ You can find a more comprehensive discussion at https://www.unicode.org/reports/tr36/
83
+ and https://www.unicode.org/faq/security.html
84
+ """
85
+
86
+ def __init__(self, ruleset="whitespaces"):
87
+ if ruleset == "whitespaces":
88
+
89
+ """Documentation:
90
+ \u00A0: Non-breaking space
91
+ \u1680: Ogham space mark
92
+ \u180E: Mongolian vowel separator
93
+ \u2000-\u200B: Various space characters, including en space, em space, thin space, hair space, zero-width space, and zero-width non-joiner
94
+ \u200C\u200D: Zero-width non-joiner and zero-width joiner
95
+ \u200E,\u200F: Left-to-right-mark, Right-to-left-mark
96
+ \u2060: Word joiner
97
+ \u2063: Invisible separator
98
+ \u202F: Narrow non-breaking space
99
+ \u205F: Medium mathematical space
100
+ \u3000: Ideographic space
101
+ \uFEFF: Zero-width non-breaking space
102
+ \uFFA0: Halfwidth hangul filler
103
+ \uFFF9\uFFFA\uFFFB: Interlinear annotation characters
104
+ \uFE00-\uFE0F: Variation selectors
105
+ \u202A-\u202F: Embedding characters
106
+ \u3164: Korean hangul filler.
107
+
108
+ Note that these characters are not always superfluous whitespace characters!
109
+ """
110
+
111
+ self.pattern = re.compile(
112
+ r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
113
+ r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
114
+ r"\u202E\u202F]"
115
+ )
116
+ elif ruleset == "IDN.blacklist":
117
+
118
+ """Documentation:
119
+ [\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF]: Matches any whitespace characters in the Unicode character
120
+ set that are included in the IDN blacklist.
121
+ \uFFF9-\uFFFB: Matches characters that are not defined in Unicode but are used as language tags in various legacy encodings.
122
+ These characters are not allowed in domain names.
123
+ \uD800-\uDB7F: Matches the first part of a surrogate pair. Surrogate pairs are used to represent characters in the Unicode character
124
+ set that cannot be represented by a single 16-bit value. The first part of a surrogate pair is in the range U+D800 to U+DBFF,
125
+ and the second part is in the range U+DC00 to U+DFFF.
126
+ \uDB80-\uDBFF][\uDC00-\uDFFF]?: Matches the second part of a surrogate pair. The second part of a surrogate pair is in the range U+DC00
127
+ to U+DFFF, and is optional.
128
+ [\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]: Matches certain invalid UTF-16 sequences which should not appear in IDNs.
129
+ """
130
+
131
+ self.pattern = re.compile(
132
+ r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
133
+ r"[\uDC00-\uDFFF]?|[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]"
134
+ )
135
+ else:
136
+ """Documentation:
137
+ This is a simple restriction to "no-unicode", using only ascii characters. Control characters are included.
138
+ """
139
+ self.pattern = re.compile(r"[^\x00-\x7F]+")
140
+
141
+ def __call__(self, text: str) -> str:
142
+ text = unicodedata.normalize("NFC", text) # canon forms
143
+ text = self.pattern.sub(" ", text) # pattern match
144
+ text = re.sub(" +", " ", text) # collapse whitespaces
145
+ text = "".join(c for c in text if unicodedata.category(c) != "Cc") # Remove any remaining non-printable characters
146
+ return text
147
+
148
+
149
+ class TrueCaser:
150
+ """True-casing is a capitalization normalization that returns text to its original capitalization.
151
+
152
+ This defends against attacks that wRIte TeXt lIkE spOngBoB.
153
+
154
+ Here, a simple POS-tagger is used.
155
+ """
156
+
157
+ uppercase_pos = ["PROPN"] # Name POS tags that should be upper-cased
158
+
159
+ def __init__(self, backend="spacy"):
160
+ if backend == "spacy":
161
+ import spacy
162
+
163
+ self.nlp = spacy.load("en_core_web_sm")
164
+ self.normalize_fn = self._spacy_truecasing
165
+ else:
166
+ from nltk import pos_tag, word_tokenize # noqa
167
+ import nltk
168
+
169
+ nltk.download("punkt")
170
+ nltk.download("averaged_perceptron_tagger")
171
+ nltk.download("universal_tagset")
172
+ self.normalize_fn = self._nltk_truecasing
173
+
174
+ def __call__(self, random_capitalized_string: str) -> str:
175
+ truecased_str = self.normalize_fn(random_capitalized_string)
176
+ return truecased_str
177
+
178
+ def _spacy_truecasing(self, random_capitalized_string: str):
179
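+ # lowercase everything, then re-capitalize proper nouns (PROPN) and sentence starts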
+ doc = self.nlp(random_capitalized_string.lower())
180
+ POS = self.uppercase_pos
181
+ truecased_str = "".join([w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws for w in doc])
182
+ return truecased_str
183
+
184
+ def _nltk_truecasing(self, random_capitalized_string: str):
185
+ from nltk import pos_tag, word_tokenize
186
+ import nltk
187
+
188
+ nltk.download("punkt")
189
+ nltk.download("averaged_perceptron_tagger")
190
+ nltk.download("universal_tagset")
191
+ POS = ["NNP", "NNPS"]
192
+
193
+ tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
194
+ truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text])
195
+ return truecased_str
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ gradio
2
+ nltk
3
+ scipy
4
+ torch
5
+ transformers
6
+ tokenizers
7
+ text_generation
watermark_processor.py ADDED
@@ -0,0 +1,280 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
3
+ # available at https://arxiv.org/abs/2301.10226
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import annotations
18
+ import collections
19
+ from math import sqrt
20
+
21
+ import scipy.stats
22
+
23
+ import torch
24
+ from torch import Tensor
25
+ from tokenizers import Tokenizer
26
+ from transformers import LogitsProcessor
27
+
28
+ from nltk.util import ngrams
29
+
30
+ from normalizers import normalization_strategy_lookup
31
+
32
+ class WatermarkBase:
33
+ def __init__(
34
+ self,
35
+ vocab: list[int] = None,
36
+ gamma: float = 0.5,
37
+ delta: float = 2.0,
38
+ seeding_scheme: str = "simple_1", # mostly unused/always default
39
+ hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
40
+ select_green_tokens: bool = True,
41
+ ):
42
+
43
+ # watermarking parameters
44
+ self.vocab = vocab
45
+ self.vocab_size = len(vocab)
46
+ self.gamma = gamma
47
+ self.delta = delta
48
+ self.seeding_scheme = seeding_scheme
49
+ self.rng = None
50
+ self.hash_key = hash_key
51
+ self.select_green_tokens = select_green_tokens
52
+
53
+ def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
54
+ # can optionally override the seeding scheme,
55
+ # but uses the instance attr by default
56
+ if seeding_scheme is None:
57
+ seeding_scheme = self.seeding_scheme
58
+
59
+ if seeding_scheme == "simple_1":
60
+ assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
61
+ prev_token = input_ids[-1].item()
62
+ self.rng.manual_seed(self.hash_key * prev_token)
63
+ else:
64
+ raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
65
+ return
66
+
67
+ def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
68
+ # seed the rng using the previous tokens/prefix
69
+ # according to the seeding_scheme
70
+ self._seed_rng(input_ids)
71
+
72
+ greenlist_size = int(self.vocab_size * self.gamma)
73
+ vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
74
+ if self.select_green_tokens: # directly
75
+ greenlist_ids = vocab_permutation[:greenlist_size] # new
76
+ else: # select green via red
77
+ greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :] # legacy behavior
78
+ return greenlist_ids
79
+
80
+
81
+ class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
82
+
83
+ def __init__(self, *args, **kwargs):
84
+ super().__init__(*args, **kwargs)
85
+
86
+ def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
87
+ # TODO lets see if we can lose this loop
88
+ green_tokens_mask = torch.zeros_like(scores)
89
+ for b_idx in range(len(greenlist_token_ids)):
90
+ green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
91
+ final_mask = green_tokens_mask.bool()
92
+ return final_mask
93
+
94
+ def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
95
+ scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
96
+ return scores
97
+
98
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
99
+
100
+ # this is lazy to allow us to colocate on the watermarked model's device
101
+ if self.rng is None:
102
+ self.rng = torch.Generator(device=input_ids.device)
103
+
104
+ # NOTE, it would be nice to get rid of this batch loop, but currently,
105
+ # the seed and partition operations are not tensor/vectorized, thus
106
+ # each sequence in the batch needs to be treated separately.
107
+ batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]
108
+
109
+ for b_idx in range(input_ids.shape[0]):
110
+ greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
111
+ batched_greenlist_ids[b_idx] = greenlist_ids
112
+
113
+ green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
114
+
115
+ scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
116
+ return scores
117
+
118
+
119
+ class WatermarkDetector(WatermarkBase):
120
+ def __init__(
121
+ self,
122
+ *args,
123
+ device: torch.device = None,
124
+ tokenizer: Tokenizer = None,
125
+ z_threshold: float = 4.0,
126
+ normalizers: list[str] = ["unicode"], # or also: ["unicode", "homoglyphs", "truecase"]
127
+ ignore_repeated_bigrams: bool = False,
128
+ **kwargs,
129
+ ):
130
+ super().__init__(*args, **kwargs)
131
+ # also configure the metrics returned/preprocessing options
132
+ assert device, "Must pass device"
133
+ assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
134
+
135
+ self.tokenizer = tokenizer
136
+ self.device = device
137
+ self.z_threshold = z_threshold
138
+ self.rng = torch.Generator(device=self.device)
139
+
140
+ if self.seeding_scheme == "simple_1":
141
+ self.min_prefix_len = 1
142
+ else:
143
+ raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")
144
+
145
+ self.normalizers = []
146
+ for normalization_strategy in normalizers:
147
+ self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
148
+
149
+ self.ignore_repeated_bigrams = ignore_repeated_bigrams
150
+ if self.ignore_repeated_bigrams:
151
+ assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."
152
+
153
+
154
+ def _compute_z_score(self, observed_count, T):
155
+ # count refers to number of green tokens, T is total number of tokens
156
+ expected_count = self.gamma
157
+ numer = observed_count - expected_count * T
158
+ denom = sqrt(T * expected_count * (1 - expected_count))
159
+ z = numer / denom
160
+ return z
161
+
162
+ def _compute_p_value(self, z):
163
+ p_value = scipy.stats.norm.sf(z)
164
+ return p_value
165
+
166
+ def _score_sequence(
167
+ self,
168
+ input_ids: Tensor,
169
+ return_num_tokens_scored: bool = True,
170
+ return_num_green_tokens: bool = True,
171
+ return_green_fraction: bool = True,
172
+ return_green_token_mask: bool = False,
173
+ return_z_score: bool = True,
174
+ return_p_value: bool = True,
175
+ ):
176
+ if self.ignore_repeated_bigrams:
177
+ # Method that only counts a green/red hit once per unique bigram.
178
+ # New num total tokens scored (T) becomes the number unique bigrams.
179
+ # We iterate over all unqiue token bigrams in the input, computing the greenlist
180
+ # induced by the first token in each, and then checking whether the second
181
+ # token falls in that greenlist.
182
+ assert return_green_token_mask == False, "Can't return the green/red mask when ignoring repeats."
183
+ bigram_table = {}
184
+ token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
185
+ freq = collections.Counter(token_bigram_generator)
186
+ num_tokens_scored = len(freq.keys())
187
+ for idx, bigram in enumerate(freq.keys()):
188
+ prefix = torch.tensor([bigram[0]], device=self.device) # expects a 1-d prefix tensor on the randperm device
189
+ greenlist_ids = self._get_greenlist_ids(prefix)
190
+ bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
191
+ green_token_count = sum(bigram_table.values())
192
+ else:
193
+ num_tokens_scored = len(input_ids) - self.min_prefix_len
194
+ if num_tokens_scored < 1:
195
+ raise ValueError((f"Must have at least {1} token to score after "
196
+ f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."))
197
+ # Standard method.
198
+ # Since we generally need at least 1 token (for the simplest scheme)
199
+ # we start the iteration over the token sequence with a minimum
200
+ # num tokens as the first prefix for the seeding scheme,
201
+ # and at each step, compute the greenlist induced by the
202
+ # current prefix and check if the current token falls in the greenlist.
203
+ green_token_count, green_token_mask = 0, []
204
+ for idx in range(self.min_prefix_len, len(input_ids)):
205
+ curr_token = input_ids[idx]
206
+ greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
207
+ if curr_token in greenlist_ids:
208
+ green_token_count += 1
209
+ green_token_mask.append(True)
210
+ else:
211
+ green_token_mask.append(False)
212
+
213
+ score_dict = dict()
214
+ if return_num_tokens_scored:
215
+ score_dict.update(dict(num_tokens_scored=num_tokens_scored))
216
+ if return_num_green_tokens:
217
+ score_dict.update(dict(num_green_tokens=green_token_count))
218
+ if return_green_fraction:
219
+ score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
220
+ if return_z_score:
221
+ score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
222
+ if return_p_value:
223
+ z_score = score_dict.get("z_score")
224
+ if z_score is None:
225
+ z_score = self._compute_z_score(green_token_count, num_tokens_scored)
226
+ score_dict.update(dict(p_value=self._compute_p_value(z_score)))
227
+ if return_green_token_mask:
228
+ score_dict.update(dict(green_token_mask=green_token_mask))
229
+
230
+ return score_dict
231
+
232
+ def detect(
233
+ self,
234
+ text: str = None,
235
+ tokenized_text: list[int] = None,
236
+ return_prediction: bool = True,
237
+ return_scores: bool = True,
238
+ z_threshold: float = None,
239
+ **kwargs,
240
+ ) -> dict:
241
+
242
+ assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
243
+ if return_prediction:
244
+ kwargs["return_p_value"] = True # to return the "confidence":=1-p of positive detections
245
+
246
+ # run optional normalizers on text
247
+ for normalizer in self.normalizers:
248
+ text = normalizer(text)
249
+ if len(self.normalizers) > 0:
250
+ print(f"Text after normalization:\n\n{text}\n")
251
+
252
+ if tokenized_text is None:
253
+ assert self.tokenizer is not None, (
254
+ "Watermark detection on raw string ",
255
+ "requires an instance of the tokenizer ",
256
+ "that was used at generation time.",
257
+ )
258
+ tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
259
+ if tokenized_text[0] == self.tokenizer.bos_token_id:
260
+ tokenized_text = tokenized_text[1:]
261
+ else:
262
+ # try to remove the bos_tok at beginning if it's there
263
+ if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
264
+ tokenized_text = tokenized_text[1:]
265
+
266
+ # call score method
267
+ output_dict = {}
268
+ score_dict = self._score_sequence(tokenized_text, **kwargs)
269
+ if return_scores:
270
+ output_dict.update(score_dict)
271
+ # if passed return_prediction then perform the hypothesis test and return the outcome
272
+ if return_prediction:
273
+ z_threshold = z_threshold if z_threshold else self.z_threshold
274
+ assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
275
+ output_dict["prediction"] = score_dict["z_score"] > z_threshold
276
+ if output_dict["prediction"]:
277
+ output_dict["confidence"] = 1 - score_dict["p_value"]
278
+
279
+ return output_dict
280
+