Spaces:
Runtime error
Runtime error
Create data_predict.py
Browse files- data_predict.py +48 -0
data_predict.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import Dataset, DatasetDict
|
2 |
+
import pandas as pd
|
3 |
+
from config import max_length, label2id
|
4 |
+
from model import tokenizer
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def convert_to_stsb_features(example_batch):
|
10 |
+
inputs = example_batch['content']
|
11 |
+
features = tokenizer.batch_encode_plus(
|
12 |
+
inputs, truncation=True, max_length=max_length, padding='max_length')
|
13 |
+
|
14 |
+
# features["labels"] = [label2id[i] for i in example_batch["sentiment"]]
|
15 |
+
features["labels"] = [0]*len(example_batch["content"]) #[i for i in range(len(example_batch["content"]))]
|
16 |
+
# features["nid"] = [int(i) for i in example_batch["nid"]]
|
17 |
+
return features
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
def convert_to_features(dataset_dict, convert_func_dict):
|
23 |
+
columns_dict = {
|
24 |
+
"document": ['input_ids', 'attention_mask', 'labels'],
|
25 |
+
# "paragraph": ['input_ids', 'attention_mask', 'labels'],
|
26 |
+
# "sentence": ['input_ids', 'attention_mask', 'labels'],
|
27 |
+
}
|
28 |
+
features_dict = {}
|
29 |
+
|
30 |
+
for task_name, dataset in dataset_dict.items():
|
31 |
+
features_dict[task_name] = {}
|
32 |
+
print(task_name)
|
33 |
+
for phase, phase_dataset in dataset.items():
|
34 |
+
features_dict[task_name][phase] = phase_dataset.map(
|
35 |
+
convert_func_dict[task_name],
|
36 |
+
batched=True,
|
37 |
+
load_from_cache_file=False,
|
38 |
+
)
|
39 |
+
print(task_name, phase, len(phase_dataset),
|
40 |
+
len(features_dict[task_name][phase]))
|
41 |
+
features_dict[task_name][phase].set_format(
|
42 |
+
type="torch",
|
43 |
+
columns=columns_dict[task_name],
|
44 |
+
)
|
45 |
+
print("=>",task_name, phase, len(phase_dataset),
|
46 |
+
len(features_dict[task_name][phase]))
|
47 |
+
return features_dict
|
48 |
+
|