Spaces:
Runtime error
Runtime error
import json | |
from tqdm import tqdm | |
import config | |
from api_wrappers import hf_data_loader | |
from generation_steps import synthetic_forward | |
def transform(df):
    """Build the data-for-labeling table and write it to a CSV artifact.

    Rows whose ``(hash, repo)`` pair appears in the manual rewriting dataset
    (loaded through ``hf_data_loader.load_raw_rewriting_as_pandas``) take their
    start/end messages from that dataset. All other rows keep their existing
    ``prediction`` and get an ``enhanced`` message from
    ``synthetic_forward.generate_end_msg``.

    Args:
        df: DataFrame with at least the columns ``hash``, ``repo``,
            ``prediction`` and ``mods``. It is modified in place.

    Returns:
        The same ``df`` object, with ``manual_sample`` and ``enhanced`` columns
        added, ``prediction`` overwritten for manual rows, and ``mods``
        replaced by its JSON string form.
    """
    print("Generating data for labeling:")
    synthetic_forward.print_config()
    tqdm.pandas()

    # Shuffle deterministically so that "keep first" below picks a
    # reproducible row among duplicates of the same (hash, repo) pair.
    rewrites = hf_data_loader.load_raw_rewriting_as_pandas()
    rewrites = rewrites.sample(frac=1, random_state=config.RANDOM_STATE)
    rewrites = rewrites.set_index(["hash", "repo"])
    rewrites = rewrites[["commit_msg_start", "commit_msg_end"]]
    rewrites = rewrites[~rewrites.index.duplicated(keep="first")]

    def commit_key(row):
        # Key matching the (hash, repo) MultiIndex of ``rewrites``.
        return (row["hash"], row["repo"])

    def is_manual(row):
        return commit_key(row) in rewrites.index

    def pick_start_message(row):
        if not row["manual_sample"]:
            return row["prediction"]
        return rewrites.loc[commit_key(row)]["commit_msg_start"]

    def pick_end_message(row):
        if not row["manual_sample"]:
            return synthetic_forward.generate_end_msg(
                start_msg=row["prediction"], diff=row["mods"]
            )
        return rewrites.loc[commit_key(row)]["commit_msg_end"]

    df["manual_sample"] = df.progress_apply(is_manual, axis=1)
    # "enhanced" is computed before "prediction" is overwritten, matching the
    # original statement order.
    df["enhanced"] = df.progress_apply(pick_end_message, axis=1)
    df["prediction"] = df.progress_apply(pick_start_message, axis=1)
    df["mods"] = df["mods"].progress_apply(json.dumps)

    df.to_csv(config.DATA_FOR_LABELING_ARTIFACT)
    print("Done")
    return df
def main():
    """Run the labeling-data transformation on the full predictions dataset.

    Sets ``synthetic_forward.GENERATION_ATTEMPTS`` to 3, then loads the frame
    via ``hf_data_loader.load_full_commit_with_predictions_as_pandas`` and
    passes it to ``transform``.
    """
    synthetic_forward.GENERATION_ATTEMPTS = 3
    transform(hf_data_loader.load_full_commit_with_predictions_as_pandas())


# Script entry point: only run the pipeline when executed directly.
if __name__ == "__main__":
    main()