import os.path
import logging
import pandas as pd
from pathlib import Path
from datetime import datetime
import csv

from utils.dedup import Dedup

class DatasetBase:
    """
    This class store and manage all the dataset records (including the annotations and prediction)
    """

    def __init__(self, config):
        """
        Initialize the dataset from a config object.
        :param config: A config object exposing `records_path`, `name` and
            `label_schema`, plus optional `sample_size`, `semantic_sampling`
            and `dedup_new_samples` entries
        """
        if config.records_path is None:
            self.records = pd.DataFrame(columns=['id', 'text', 'prediction',
                                                 'annotation', 'metadata', 'score', 'batch_id'])
        else:
            self.records = pd.read_csv(config.records_path)
        dt_string = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")

        self.name = config.name + '__' + dt_string
        self.label_schema = config.label_schema
        self.dedup = Dedup(config)
        self.sample_size = config.get("sample_size", 3)
        self.semantic_sampling = config.get("semantic_sampling", False)
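        # If dedup of new samples is disabled, swap in the identity function so
        # that remove_duplicates becomes a per-instance no-op.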
        if not config.get('dedup_new_samples', False):
            self.remove_duplicates = self._null_remove

    def __len__(self):
        """
        Return the number of samples in the dataset.
        """
        return len(self.records)

    def __getitem__(self, batch_idx):
        """
        Return the records belonging to batch `batch_idx`.
        """
        extract_records = self.records[self.records['batch_id'] == batch_idx]
        extract_records = extract_records.reset_index(drop=True)
        return extract_records

    def get_leq(self, batch_idx):
        """
        Return all the records up to and including `batch_idx`.
        """
        extract_records = self.records[self.records['batch_id'] <= batch_idx]
        extract_records = extract_records.reset_index(drop=True)
        return extract_records

    def add(self, sample_list: list = None, batch_id: int = None, records: pd.DataFrame = None):
        """
        Add records to the dataset.
        :param sample_list: The samples to add, as a list of texts (only used when records is None)
        :param batch_id: The batch_id for the uploaded records (only used when records is None)
        :param records: A dataframe of records to append as-is
        """
        if records is None:
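            # Build fresh records from the raw texts, assigning sequential ids
            # that continue from the current end of the dataset.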
            records = pd.DataFrame([{'id': len(self.records) + i, 'text': sample, 'batch_id': batch_id}
                                    for i, sample in enumerate(sample_list)])
        self.records = pd.concat([self.records, records], ignore_index=True)

    def update(self, records: pd.DataFrame):
        """
        Update existing records in the dataset, matching rows on 'id', and
        drop any record whose annotation was marked 'Discarded'.
        """
        # Ignore if records is empty
        if len(records) == 0:
            return

        # Set 'id' as the index for both DataFrames; reassign rather than
        # mutate in place so the caller's DataFrame is left untouched
        records = records.set_index('id')
        self.records.set_index('id', inplace=True)

        # Update using 'id' as the key
        self.records.update(records)

        # Move records marked as 'Discarded' out of the active set
        if len(self.records.loc[self.records["annotation"] == "Discarded"]) > 0:
            discarded_annotation_records = self.records.loc[self.records["annotation"] == "Discarded"]
            # TODO: direct `discarded_annotation_records` to another dataset to be used later for corner-cases
            self.records = self.records.loc[self.records["annotation"] != "Discarded"]

        # Reset index
        self.records.reset_index(inplace=True)

    def modify(self, index: int, record: dict):
        """
        Overwrite the record at the given index.
        """
        # .loc targets the row by label; plain indexing would create a column
        self.records.loc[index] = record

    def apply(self, function, column_name: str):
        """
        Apply a function to each record and store the result in `column_name`.
        """
        self.records[column_name] = self.records.apply(function, axis=1)

    def save_dataset(self, path: Path):
        """
        Dump the dataset records to a csv file.
        :param path: path for the csv
        """
        self.records.to_csv(path, index=False, quoting=csv.QUOTE_NONNUMERIC)

    def load_dataset(self, path: Path):
        """
        Load the dataset records from a csv dump.
        :param path: path for the csv
        """
        if os.path.isfile(path):
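            # Force string dtypes on 'annotation'/'prediction' so comparisons
            # such as == 'Discarded' in update() behave consistently on reload.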
            self.records = pd.read_csv(path, dtype={'annotation': str, 'prediction': str, 'batch_id': int})
        else:
            logging.warning('Dataset dump not found, initializing from zero')

    def remove_duplicates(self, samples: list) -> list:
        """
        Remove (soft) duplicates from the given samples
        :param samples: The samples
        :return: The samples without duplicates
        """
        dd = self.dedup.copy()
        df = pd.DataFrame(samples, columns=['text'])
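        # Cluster near-duplicate texts and keep a single representative per
        # cluster; the exact selection policy is delegated to the Dedup utility.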
        df_dedup = dd.sample(df, operation_function=min)
        return df_dedup['text'].tolist()

    def _null_remove(self, samples: list) -> list:
        # Identity function that returns the input unmodified
        return samples

    def sample_records(self, n: int = None) -> pd.DataFrame:
        """
        Return a sample of the records, clustered semantically when
        semantic_sampling is enabled.
        :param n: The number of samples to return (defaults to sample_size)
        :return: A sample of the records
        """
        n = n or self.sample_size
        if self.semantic_sampling:
            dd = self.dedup.copy()
            df_samples = dd.sample(self.records).head(n)

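            # Fall back to the first n records if clustering returned fewer
            # samples than requested.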
            if len(df_samples) < n:
                df_samples = self.records.head(n)
        else:
            df_samples = self.records.sample(min(n, len(self.records)))
        return df_samples

    @staticmethod
    def samples_to_text(records: pd.DataFrame) -> str:
        """
        Return a string that organizes the samples for a meta-prompt.
        :param records: The samples for the step
        :return: A string that contains the organized samples
        """
        txt_res = '##\n'
        for _, row in records.iterrows():
            txt_res += f"Sample:\n {row.text}\n#\n"
        return txt_res
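
# A minimal usage sketch (illustrative only, kept commented out): DatasetBase
# only assumes a config object with attribute access plus a dict-style .get(),
# e.g. easydict.EasyDict; the field values below are hypothetical, and
# Dedup(config) may require additional config fields in practice.
#
#   from easydict import EasyDict
#
#   config = EasyDict({'records_path': None,
#                      'name': 'demo_dataset',
#                      'label_schema': ['Yes', 'No']})
#   dataset = DatasetBase(config)
#   dataset.add(sample_list=['first sample', 'second sample'], batch_id=0)
#   print(len(dataset))  # -> 2
#   print(dataset[0])    # the records of batch 0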