---
license: cc-by-nd-4.0
---
## Czech Metrum Validator
Validator for poetic metre (metrum), trained on Czech poetry from the GitHub project by the
Institute of Czech Literature, Czech Academy of Sciences.

https://github.com/versotym/corpusCzechVerse

## Usage

### Loading model
Download `validator.py`, which contains the interface, then download the model and load it with PyTorch:

```python
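import torch
from validator import ValidatorInterface  # interface from the downloaded validator.py
# args.metre_model_path_full is the path to the downloaded model file
model: ValidatorInterface = torch.load(args.metre_model_path_full, map_location=torch.device('cpu'))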
```

Load the base RobeCzech tokenizer and try the model out:

```python
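from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
# datum is one preprocessed example holding "input_ids" and "metre"
model.validate(input_ids=datum["input_ids"], metre=datum["metre"])['acc']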
```
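To score more than one example, the `validate` call above can be wrapped in a small loop. The following is only a sketch: it assumes that `validate` returns a dict whose `'acc'` entry is numeric, and `mean_accuracy`, `StubValidator`, and the example data are hypothetical names introduced here so the snippet runs without the real weights.

```python
from typing import Any, Dict, Iterable, List


def mean_accuracy(model: Any, data: Iterable[Dict[str, Any]]) -> float:
    """Average the 'acc' value returned by model.validate over all examples."""
    scores: List[float] = [
        float(model.validate(input_ids=d["input_ids"], metre=d["metre"])["acc"])
        for d in data
    ]
    return sum(scores) / len(scores) if scores else 0.0


class StubValidator:
    """Hypothetical stand-in for the loaded model (not the real validator)."""

    def validate(self, input_ids, metre):
        # Pretend the model always predicts class 0 of a one-hot metre vector.
        return {"acc": 1.0 if metre[0] == 1 else 0.0}


# Made-up examples; real ones come from the tokenizer and the corpus labels.
examples = [
    {"input_ids": [0, 11, 2], "metre": [1, 0]},
    {"input_ids": [0, 12, 2], "metre": [0, 1]},
]
score = mean_accuracy(StubValidator(), examples)
print(score)
```

With the real model, pass the object returned by `torch.load` in place of `StubValidator()`.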

### Train Model

```python
import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments

meter_model = MeterValidator(pretrained_model=args.pretrained_model)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

training_args = TrainingArguments(
  save_strategy  = "no",
  logging_steps = 500,
  warmup_steps = args.worm_up,
  weight_decay = 0.0,
  num_train_epochs = args.epochs,
  learning_rate = args.learning_rate,
  fp16 = torch.cuda.is_available(),
  ddp_backend = "nccl",
  lr_scheduler_type="cosine",
  logging_dir = './logs',
  output_dir = './results',
  per_device_train_batch_size = args.batch_size)

Trainer(model = meter_model,
  args = training_args,
  train_dataset= train_data.pytorch_dataset_body,
  data_collator=collate).train()

```
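The `collate` function passed to `Trainer` is not shown above. As a rough sketch of what such a collator might do, the snippet below pads every `input_ids` sequence to the longest one in the batch and builds a matching attention mask. It assumes each example is a dict with variable-length `input_ids` and a fixed-length `metre` label vector; `pad_batch` and `PAD_ID` are hypothetical names, and plain lists are used to keep the sketch dependency-free (a real collator would wrap them in `torch.tensor`).

```python
from typing import Dict, List

PAD_ID = 1  # hypothetical padding id; use tokenizer.pad_token_id in practice


def pad_batch(batch: List[Dict[str, list]], pad_id: int = PAD_ID) -> Dict[str, list]:
    """Pad input_ids to the longest sequence in the batch and build an attention mask."""
    max_len = max(len(example["input_ids"]) for example in batch)
    input_ids, attention_mask, metre = [], [], []
    for example in batch:
        ids = list(example["input_ids"])
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)             # right-pad the tokens
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 0 marks padding
        metre.append(list(example["metre"]))                 # labels pass through
    return {"input_ids": input_ids, "attention_mask": attention_mask, "metre": metre}


# Made-up token ids and labels, for illustration only.
batch = pad_batch([
    {"input_ids": [0, 11, 12, 2], "metre": [1.0, 0.0]},
    {"input_ids": [0, 13, 2], "metre": [0.0, 1.0]},
])
print(batch["input_ids"])
```

Padding per batch rather than to a global maximum keeps batches of short verses small, which is one common reason to supply a custom `data_collator`.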