def drop_duplicates_in_input(untokenized_dataset):
    """Merge rows that share an ``id`` into one row holding all their outputs.

    The first row seen for each id is kept; the ``output`` value of every row
    with that id (the first one included) is gathered, in encounter order,
    into a list. The single-valued ``output`` column is then replaced by an
    ``outputs`` column containing those lists.

    Args:
        untokenized_dataset: Dataset-like object exposing ``"id"`` and
            ``"output"`` columns via indexing, plus ``select``,
            ``flatten_indices``, ``remove_columns`` and ``add_column``.
            NOTE(review): presumably a Hugging Face ``datasets.Dataset`` --
            confirm against the caller.

    Returns:
        The result of the chained dataset calls: one row per distinct id,
        without the ``output`` column and with the new ``outputs`` column.
    """
    # id -> (index of the first row with that id, outputs gathered so far).
    # Dicts preserve insertion order, so iterating the values later yields
    # the groups in order of first appearance.
    groups = {}
    id_output_pairs = zip(untokenized_dataset["id"], untokenized_dataset["output"])
    for row_idx, (example_id, output) in enumerate(id_output_pairs):
        group = groups.get(example_id)
        if group is None:
            groups[example_id] = (row_idx, [output])
        else:
            group[1].append(output)

    indices_to_keep = [first_idx for first_idx, _ in groups.values()]
    outputs = [collected for _, collected in groups.values()]

    return (
        untokenized_dataset.select(indices_to_keep)
        .flatten_indices()
        .remove_columns("output")
        .add_column("outputs", outputs)
    )