"""Build the data-for-labeling CSV from commits with predicted messages.

Rows that also appear in the manual rewriting dataset take their start/end
messages from it; every other row keeps its prediction and gets an end
message from ``synthetic_forward.generate_end_msg``.
"""

import json

from tqdm import tqdm

import config
from api_wrappers import hf_data_loader
from generation_steps import synthetic_forward


def transform(df):
    """Add labeling columns to ``df``, write it to CSV and return it.

    The frame is modified in place: the returned object is the same one that
    was passed in.

    Args:
        df: DataFrame whose rows provide ``hash``, ``repo``, ``prediction``
            and ``mods``.

    Returns:
        ``df`` with a new boolean-like ``manual_sample`` column, a new
        ``enhanced`` column, ``prediction`` replaced for manually rewritten
        commits, and ``mods`` serialized with ``json.dumps``.

    Side effects:
        Prints progress messages, enables tqdm's pandas integration, and
        writes the result to ``config.DATA_FOR_LABELING_ARTIFACT``.
    """
    print("Generating data for labeling:")
    synthetic_forward.print_config()
    tqdm.pandas()

    # Shuffle with a fixed seed before de-duplicating, so that which of
    # several rows for the same (hash, repo) survives is reproducible.
    raw_rewrites = hf_data_loader.load_raw_rewriting_as_pandas()
    shuffled = raw_rewrites.sample(frac=1, random_state=config.RANDOM_STATE)
    keyed = shuffled.set_index(["hash", "repo"])
    rewrites = keyed[["commit_msg_start", "commit_msg_end"]]
    rewrites = rewrites[~rewrites.index.duplicated(keep="first")]

    def commit_key(row):
        """Return the (hash, repo) pair identifying the row's commit."""
        return row["hash"], row["repo"]

    def is_manual(row):
        """Tell whether the row's commit is present in the manual rewrites."""
        return commit_key(row) in rewrites.index

    def pick_enhanced(row):
        """Return the manual end message, or a generated one otherwise."""
        if not row["manual_sample"]:
            return synthetic_forward.generate_end_msg(start_msg=row["prediction"], diff=row["mods"])
        return rewrites.loc[commit_key(row)]["commit_msg_end"]

    def pick_prediction(row):
        """Return the manual start message, or the existing prediction."""
        if not row["manual_sample"]:
            return row["prediction"]
        return rewrites.loc[commit_key(row)]["commit_msg_start"]

    # NOTE: no copy is taken; the caller's frame receives the new columns.
    df["manual_sample"] = df.progress_apply(is_manual, axis=1)
    df["enhanced"] = df.progress_apply(pick_enhanced, axis=1)
    df["prediction"] = df.progress_apply(pick_prediction, axis=1)
    df["mods"] = df["mods"].progress_apply(json.dumps)

    df.to_csv(config.DATA_FOR_LABELING_ARTIFACT)
    print("Done")
    return df


def main():
    """Load commits with predictions and run :func:`transform` on them."""
    # Overrides the module-level attempt count used by synthetic_forward.
    synthetic_forward.GENERATION_ATTEMPTS = 3
    commits = hf_data_loader.load_full_commit_with_predictions_as_pandas()
    transform(commits)


if __name__ == "__main__":
    main()