yhavinga committed on
Commit
8cd0b56
•
1 Parent(s): a19a543

Add some models

Browse files
Files changed (3) hide show
  1. app.py +45 -8
  2. generator.py +16 -17
  3. requirements.txt +2 -1
app.py CHANGED
@@ -12,23 +12,59 @@ TRANSLATION_NL_TO_EN = "translation_en_to_nl"
12
 
13
  GENERATOR_LIST = [
14
  {
15
- "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
16
- "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
17
  "task": TRANSLATION_NL_TO_EN,
18
- "split_sentences": False,
 
 
 
 
 
 
19
  },
20
  {
21
- "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
22
- "desc": "longT5 large nl8 512beta/512l en->nl",
23
  "task": TRANSLATION_NL_TO_EN,
24
  "split_sentences": False,
25
  },
26
  {
27
- "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
28
- "desc": "T5 small nl24 ccmatrix en->nl",
29
  "task": TRANSLATION_NL_TO_EN,
30
  "split_sentences": True,
31
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ]
33
 
34
 
@@ -64,7 +100,7 @@ It was a quite young girl, unknown to me, with a hood over her head, and with la
64
 
65
 “My father is very ill,” she said without a word of introduction. “The nurse is frightened. Could you come in and help?”"""
66
  st.session_state["text"] = st.text_area(
67
- "Enter text", st.session_state.prompt_box, height=300
68
  )
69
  num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
70
  num_beam_groups = st.sidebar.number_input(
@@ -83,6 +119,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
83
  "num_beams": num_beams,
84
  "num_beam_groups": num_beam_groups,
85
  "length_penalty": length_penalty,
 
86
  }
87
 
88
  if st.button("Run"):
 
12
 
13
  GENERATOR_LIST = [
14
  {
15
+ "model_name": "Helsinki-NLP/opus-mt-en-nl",
16
+ "desc": "Opus MT en->nl",
17
  "task": TRANSLATION_NL_TO_EN,
18
+ "split_sentences": True,
19
+ },
20
+ {
21
+ "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
22
+ "desc": "T5 small nl24 ccmatrix en->nl",
23
+ "task": TRANSLATION_NL_TO_EN,
24
+ "split_sentences": True,
25
  },
26
  {
27
+ "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
28
+ "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
29
  "task": TRANSLATION_NL_TO_EN,
30
  "split_sentences": False,
31
  },
32
  {
33
+ "model_name": "yhavinga/byt5-small-ccmatrix-en-nl",
34
+ "desc": "ByT5 small ccmatrix en->nl",
35
  "task": TRANSLATION_NL_TO_EN,
36
  "split_sentences": True,
37
  },
38
+ # {
39
+ # "model_name": "yhavinga/t5-eff-large-8l-nedd-en-nl",
40
+ # "desc": "T5 eff large nl8 en->nl",
41
+ # "task": TRANSLATION_NL_TO_EN,
42
+ # "split_sentences": True,
43
+ # },
44
+ # {
45
+ # "model_name": "yhavinga/t5-base-36L-ccmatrix-multi",
46
+ # "desc": "T5 base nl36 ccmatrix en->nl",
47
+ # "task": TRANSLATION_NL_TO_EN,
48
+ # "split_sentences": True,
49
+ # },
50
+ # {
51
+ # "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
52
+ # "desc": "longT5 large nl8 512beta/512l en->nl",
53
+ # "task": TRANSLATION_NL_TO_EN,
54
+ # "split_sentences": False,
55
+ # },
56
+ # {
57
+ # "model_name": "yhavinga/t5-base-36L-nedd-x-en-nl-300",
58
+ # "desc": "T5 base 36L nedd en->nl 300",
59
+ # "task": TRANSLATION_NL_TO_EN,
60
+ # "split_sentences": True,
61
+ # },
62
+ # {
63
+ # "model_name": "yhavinga/long-t5-local-small-ccmatrix-en-nl",
64
+ # "desc": "longT5 small ccmatrix en->nl",
65
+ # "task": TRANSLATION_NL_TO_EN,
66
+ # "split_sentences": True,
67
+ # },
68
  ]
69
 
70
 
 
100
 
101
 “My father is very ill,” she said without a word of introduction. “The nurse is frightened. Could you come in and help?”"""
102
  st.session_state["text"] = st.text_area(
103
+ "Enter text", st.session_state.prompt_box, height=250
104
  )
105
  num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
106
  num_beam_groups = st.sidebar.number_input(
 
119
  "num_beams": num_beams,
120
  "num_beam_groups": num_beam_groups,
121
  "length_penalty": length_penalty,
122
+ "early_stopping": True,
123
  }
124
 
125
  if st.button("Run"):
generator.py CHANGED
@@ -30,9 +30,19 @@ def load_model(model_name):
30
  if tokenizer.pad_token is None:
31
  print("Adding pad_token to the tokenizer")
32
  tokenizer.pad_token = tokenizer.eos_token
33
- model = AutoModelForSeq2SeqLM.from_pretrained(
34
- model_name, from_flax=True, use_auth_token=get_access_token()
35
- )
 
 
 
 
 
 
 
 
 
 
36
  if device != -1:
37
  model.to(f"cuda:{device}")
38
  return tokenizer, model
@@ -66,24 +76,13 @@ class Generator:
66
  for key in self.gen_kwargs:
67
  if key in self.model.config.__dict__:
68
  self.gen_kwargs[key] = self.model.config.__dict__[key]
69
- print(
70
- "Setting",
71
- key,
72
- "to",
73
- self.gen_kwargs[key],
74
- "for model",
75
- self.model_name,
76
- )
77
  try:
78
  if self.task in self.model.config.task_specific_params:
79
  task_specific_params = self.model.config.task_specific_params[
80
  self.task
81
  ]
82
- self.prefix = (
83
- task_specific_params["prefix"]
84
- if "prefix" in task_specific_params
85
- else ""
86
- )
87
  for key in self.gen_kwargs:
88
  if key in task_specific_params:
89
  self.gen_kwargs[key] = task_specific_params[key]
@@ -95,7 +94,7 @@ class Generator:
95
  text = re.sub(r"\n{2,}", "\n", text)
96
  generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
97
 
98
- # if there are newlines in the text, and the model needs line-splitting, split the text
99
  if re.search(r"\n", text) and self.split_sentences:
100
  lines = text.splitlines()
101
  translated = [self.generate(line, **generate_kwargs)[0] for line in lines]
 
30
  if tokenizer.pad_token is None:
31
  print("Adding pad_token to the tokenizer")
32
  tokenizer.pad_token = tokenizer.eos_token
33
+ try:
34
+ model = AutoModelForSeq2SeqLM.from_pretrained(
35
+ model_name, use_auth_token=get_access_token()
36
+ )
37
+ except EnvironmentError:
38
+ try:
39
+ model = AutoModelForSeq2SeqLM.from_pretrained(
40
+ model_name, from_flax=True, use_auth_token=get_access_token()
41
+ )
42
+ except EnvironmentError:
43
+ model = AutoModelForSeq2SeqLM.from_pretrained(
44
+ model_name, from_tf=True, use_auth_token=get_access_token()
45
+ )
46
  if device != -1:
47
  model.to(f"cuda:{device}")
48
  return tokenizer, model
 
76
  for key in self.gen_kwargs:
77
  if key in self.model.config.__dict__:
78
  self.gen_kwargs[key] = self.model.config.__dict__[key]
 
 
 
 
 
 
 
 
79
  try:
80
  if self.task in self.model.config.task_specific_params:
81
  task_specific_params = self.model.config.task_specific_params[
82
  self.task
83
  ]
84
+ if "prefix" in task_specific_params:
85
+ self.prefix = task_specific_params["prefix"]
 
 
 
86
  for key in self.gen_kwargs:
87
  if key in task_specific_params:
88
  self.gen_kwargs[key] = task_specific_params[key]
 
94
  text = re.sub(r"\n{2,}", "\n", text)
95
  generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
96
 
97
+ # if there are newlines in the text, and the model needs line-splitting, split the text and recurse
98
  if re.search(r"\n", text) and self.split_sentences:
99
  lines = text.splitlines()
100
  translated = [self.generate(line, **generate_kwargs)[0] for line in lines]
requirements.txt CHANGED
@@ -5,9 +5,10 @@ protobuf<3.20
5
  streamlit>=1.4.0,<=1.10.0
6
  torch
7
  transformers>=4.13.0
8
- mtranslate
9
  psutil
10
  jax[cuda]==0.3.16
11
  chex>=0.1.4
12
  ##jaxlib==0.1.67
13
  flax>=0.5.3
 
 
5
  streamlit>=1.4.0,<=1.10.0
6
  torch
7
  transformers>=4.13.0
8
+ langdetect
9
  psutil
10
  jax[cuda]==0.3.16
11
  chex>=0.1.4
12
  ##jaxlib==0.1.67
13
  flax>=0.5.3
14
+ sentencepiece