wq2012 committed
Commit ef5762f (verified) · Parent(s): 63d5d6f

Update README.md

Files changed (1): README.md (+122, -3). The previous README contained only the `license: llama3` front matter; the updated file is reproduced below.

README.md:
---
license: llama3
---

**This is not an officially supported Google product.**

## Overview

[DiarizationLM](https://arxiv.org/abs/2401.03506) model finetuned on the training subset of the Fisher corpus.

* Foundation model: [unsloth/llama-3-8b-bnb-4bit](https://huggingface.co/unsloth/llama-3-8b-bnb-4bit)
* Finetuning scripts: https://github.com/google/speaker-id/tree/master/DiarizationLM/unsloth

The difference between this model and [google/DiarizationLM-8b-Fisher-v1](https://huggingface.co/google/DiarizationLM-8b-Fisher-v1) is how the loss is computed:

* For this model, the loss is only computed on the completion tokens (see the sketch after this list).
* For `google/DiarizationLM-8b-Fisher-v1`, the loss is also computed on the prompt tokens.
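
In Hugging Face-style causal LM finetuning, computing the loss only on the completion tokens is usually done by masking the prompt positions in the labels. The following is a minimal sketch of that idea, not the actual unsloth training code; the example strings are shortened from the Usage section below:

```python
from transformers import AutoTokenizer

# Illustrative only: "loss on completion tokens only" is typically implemented by
# setting the label of every prompt position to -100, the ignore index of the
# cross-entropy loss in Hugging Face causal LM training.
tokenizer = AutoTokenizer.from_pretrained("google/DiarizationLM-8b-Fisher-v2")

prompt = "<speaker:1> Hello, how are you doing <speaker:2> today? --> "
completion = "<speaker:1> Hello, how are you doing today? [eod]"

prompt_ids = tokenizer(prompt, add_special_tokens=False).input_ids
completion_ids = tokenizer(completion, add_special_tokens=False).input_ids

input_ids = prompt_ids + completion_ids
# This model: prompt positions are masked out, so only the completion contributes
# to the loss.
labels = [-100] * len(prompt_ids) + completion_ids
# DiarizationLM-8b-Fisher-v1 (for comparison): the labels would simply be a copy
# of input_ids, i.e. loss on prompt and completion alike.
```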

## Training config

This model is finetuned on the training subset of the Fisher corpus, using a LoRA adapter of rank 256. The total number of trainable parameters is 671,088,640. With a batch size of 16, the model has been trained for 28,800 steps, which is roughly 9 epochs of the training data.
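
For reference, this adapter size is consistent with applying rank-256 LoRA to all attention and MLP projection matrices of Llama 3 8B. Below is a minimal configuration sketch using the `peft` library; the actual finetuning uses the unsloth scripts linked above, and `lora_alpha` and `lora_dropout` here are assumptions:

```python
from peft import LoraConfig

# Illustrative only: a rank-256 LoRA adapter comparable to the one described above.
# With these target modules, Llama 3 8B yields 671,088,640 trainable parameters.
lora_config = LoraConfig(
    r=256,            # LoRA rank used for this model
    lora_alpha=256,   # assumed scaling factor
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.0,  # assumed
    task_type="CAUSAL_LM",
)
```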

We use the `mixed` flavor during training, meaning we combine data from the `hyp2ora` and `deg2ref` flavors. After the prompt builder, we have a total of 51,063 prompt-completion pairs in our training set.
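
To make the data format concrete, here is roughly what one prompt-completion pair looks like after the prompt builder, reusing the example hypothesis from the Usage section. This is illustrative only; the `[eod]` end-of-data marker matches the one visible in the generated output below:

```python
# Illustrative prompt-completion pair: the prompt is the hypothesis plus the
# " --> " suffix, and the completion carries the corrected speaker labels.
prompt = (
    "<speaker:1> Hello, how are you doing <speaker:2> today? I am doing well."
    " What about <speaker:1> you? I'm doing well, too. Thank you. --> "
)
completion = (
    "<speaker:1> Hello, how are you doing today? <speaker:2> I am doing well."
    " What about you? <speaker:1> I'm doing well, too. Thank you. [eod]"
)
```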

The finetuning took more than 4 days on a Google Cloud VM instance with a single NVIDIA A100 GPU (80GB memory).

The maximum length of the prompt to this model is 6000 characters, including the " --> " suffix. The maximum sequence length is 4096 tokens.
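
Hypotheses longer than this must be segmented before prompting the model. A minimal guard for building prompts is sketched below; the `build_prompt` helper is hypothetical and only illustrative, and the DiarizationLM repository provides its own segmentation tooling:

```python
MAX_PROMPT_CHARS = 6000  # includes the " --> " suffix
PROMPT_SUFFIX = " --> "

def build_prompt(hypothesis: str) -> str:
    """Builds a prompt, refusing hypotheses that would exceed the limit."""
    prompt = hypothesis + PROMPT_SUFFIX
    if len(prompt) > MAX_PROMPT_CHARS:
        raise ValueError(
            f"Prompt has {len(prompt)} characters; the model accepts at most "
            f"{MAX_PROMPT_CHARS}. Segment the hypothesis into shorter chunks."
        )
    return prompt
```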

## Metrics

### Fisher testing set

| System | WER (%) | WDER (%) | cpWER (%) |
| ------------------------------ | ------- | -------- | --------- |
| USM + turn-to-diarize baseline | 15.48   | 5.32     | 21.19     |
| + This model                   | -       | 3.28     | 18.37     |

### Callhome testing set

| System | WER (%) | WDER (%) | cpWER (%) |
| ------------------------------ | ------- | -------- | --------- |
| USM + turn-to-diarize baseline | 15.36   | 7.72     | 24.39     |
| + This model                   | -       | 6.66     | 23.57     |
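
Here WDER is the word diarization error rate (the fraction of words assigned to the wrong speaker) and cpWER is the concatenated minimum-permutation WER. The sketch below illustrates WDER under the simplifying assumptions that hypothesis and reference words are already aligned one-to-one and that speaker labels have already been mapped to a common space; the real metric also handles ASR errors and the optimal speaker mapping:

```python
def simple_wder(ref_speakers: list[str], hyp_speakers: list[str]) -> float:
    """Fraction of aligned words whose hypothesized speaker label is wrong.

    Illustrative only: assumes a one-to-one word alignment and pre-mapped
    speaker labels.
    """
    assert len(ref_speakers) == len(hyp_speakers)
    errors = sum(r != h for r, h in zip(ref_speakers, hyp_speakers))
    return errors / len(ref_speakers)

# Example: 1 out of 4 words carries the wrong speaker label -> WDER = 0.25.
print(simple_wder(["1", "1", "2", "2"], ["1", "1", "1", "2"]))
```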

## Usage

First, you need to install two packages:

```
pip install transformers diarizationlm
```

On a machine with a GPU and CUDA, you can use the model by running the following script:

```python
from transformers import LlamaForCausalLM, AutoTokenizer
from diarizationlm import utils

HYPOTHESIS = """<speaker:1> Hello, how are you doing <speaker:2> today? I am doing well. What about <speaker:1> you? I'm doing well, too. Thank you."""

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained("google/DiarizationLM-8b-Fisher-v2", device_map="cuda")
model = LlamaForCausalLM.from_pretrained("google/DiarizationLM-8b-Fisher-v2", device_map="cuda")

print("Tokenizing input...")
inputs = tokenizer([HYPOTHESIS + " --> "], return_tensors="pt").to("cuda")

print("Generating completion...")
outputs = model.generate(**inputs,
                         max_new_tokens=int(inputs.input_ids.shape[1] * 1.2),
                         use_cache=False)

print("Decoding completion...")
completion = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:],
                                    skip_special_tokens=True)[0]

print("Transferring completion to hypothesis text...")
transferred_completion = utils.transfer_llm_completion(completion, HYPOTHESIS)

print("========================================")
print("Hypothesis:", HYPOTHESIS)
print("========================================")
print("Completion:", completion)
print("========================================")
print("Transferred completion:", transferred_completion)
print("========================================")
```

The output will look like this:

```
Loading model...
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████| 4/4 [00:13<00:00,  3.32s/it]
generation_config.json: 100%|████████████████████████████████████████████████████████████████████████████████| 172/172 [00:00<00:00, 992kB/s]
Tokenizing input...
Generating completion...
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Decoding completion...
Transferring completion to hypothesis text...
========================================
Hypothesis: <speaker:1> Hello, how are you doing <speaker:2> today? I am doing well. What about <speaker:1> you? I'm doing well, too. Thank you.
========================================
Completion: <speaker:1> Hello, how are you doing today? <speaker:2> i am doing well. What about you? <speaker:1> i'm doing well, too. Thank you. [eod] [eod] <speaker:2
========================================
Transferred completion: <speaker:1> Hello, how are you doing today? <speaker:2> I am doing well. What about you? <speaker:1> I'm doing well, too. Thank you.
========================================
```

## Citation

Our paper is cited as:

```
@article{wang2024diarizationlm,
  title={{DiarizationLM: Speaker Diarization Post-Processing with Large Language Models}},
  author={Quan Wang and Yiling Huang and Guanlong Zhao and Evan Clark and Wei Xia and Hank Liao},
  journal={arXiv preprint arXiv:2401.03506},
  year={2024}
}
```