ai4anshu's picture
Upload 8 files
d4caa5c verified
import os
import json
import pandas as pd
import yaml
import seaborn as sns
import matplotlib.pyplot as plt
from inference import get_latest_checkpoint
def process_loss(loss, final_loss):
    """Append one log entry's epoch and tracked metrics to ``final_loss``.

    Args:
        loss: A single log record. It must contain ``"epoch"``; any of the
            tracked metric keys it also contains are recorded.
        final_loss: Accumulator mapping ``"epoch"`` and each metric name to a
            list. It is mutated in place.

    Raises:
        KeyError: If ``loss`` has no ``"epoch"`` key or ``final_loss`` has no
            ``"epoch"`` list.
    """
    epoch_number = int(loss["epoch"])
    final_loss["epoch"].append(epoch_number)

    tracked_metrics = ("loss", "eval_loss", "eval_rouge1", "eval_rouge2")
    for metric in tracked_metrics:
        try:
            reading = loss[metric]
            final_loss[metric].append(reading)
        except KeyError:
            # Best effort: a metric absent from this record (or from the
            # accumulator) is simply not recorded.
            continue
def loss_function(losses):
    """Collect per-epoch metrics from a sequence of log records.

    Only records whose ``"epoch"`` value is a whole number are used; records
    without an ``"epoch"`` key are skipped. Readings are grouped by epoch so
    that every returned list has one slot per epoch, in ascending epoch order.

    Args:
        losses: Iterable of dict-like log records. Each may carry ``"epoch"``
            plus any of ``"loss"``, ``"eval_loss"``, ``"eval_rouge1"`` and
            ``"eval_rouge2"``.

    Returns:
        A dict with keys ``"epoch"``, ``"loss"``, ``"eval_loss"``,
        ``"eval_rouge1"`` and ``"eval_rouge2"`` (in that order), each mapping
        to a list of equal length. A metric that was never reported for an
        epoch is represented by ``float("nan")``. If a metric is reported
        more than once for the same epoch, the last reading wins.
    """
    metric_names = ("loss", "eval_loss", "eval_rouge1", "eval_rouge2")

    # epoch number -> {metric name: reading}
    readings_by_epoch = {}
    for record in losses:
        if "epoch" not in record:
            # Without an epoch there is nowhere to place the reading.
            continue
        epoch_value = float(record["epoch"])
        if not epoch_value.is_integer():
            # Mid-epoch log lines are not plotted.
            continue
        slot = readings_by_epoch.setdefault(int(epoch_value), {})
        for name in metric_names:
            if name in record:
                slot[name] = record[name]

    # Sorting gives a deterministic x-axis; a set has no guaranteed order.
    epochs = sorted(readings_by_epoch)
    final_loss = {"epoch": epochs}
    for name in metric_names:
        # Pad gaps with NaN so every column lines up with the epoch list.
        final_loss[name] = [
            readings_by_epoch[epoch].get(name, float("nan")) for epoch in epochs
        ]
    return final_loss
def plot_loss(data, output_dir):
    """Plot every metric against epoch and save the chart as a PNG.

    Args:
        data: Column-oriented metrics accepted by ``pd.DataFrame``. It must
            include an ``"epoch"`` column; every other column is drawn as its
            own line.
        output_dir: Directory that receives ``metrics_vs_epoch.png``. The
            directory is expected to exist already.
    """
    frame = pd.DataFrame(data)
    # Long format: one row per (epoch, metric) pair, so seaborn can colour by metric.
    long_frame = pd.melt(frame, id_vars=['epoch'], var_name='metric', value_name='value')

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        sns.lineplot(data=long_frame, x='epoch', y='value', hue='metric', marker='o', ax=ax)
        ax.legend(title='Metric')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Value')
        ax.set_title('Metrics vs Epoch')
        fig.savefig(os.path.join(output_dir, 'metrics_vs_epoch.png'))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it even when plotting or saving fails.
        plt.close(fig)
if __name__ == "__main__":
    # Read the run configuration; the context manager closes the file handle.
    with open("config.yaml", "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    # SECURITY(review): eval() executes whatever expression config.yaml holds
    # under PROJECT_DIR. Kept because the config appears to store a Python
    # expression here, but it must never be fed an untrusted config file.
    PROJECT_DIR = eval(config["SENTENCE_COMPRESSION"]["PROJECT_DIR"])
    checkpoint_dir = config["SENTENCE_COMPRESSION"]["INFERENCE"]["MODEL_PATH"]

    # Locate trainer_state.json inside the most recent checkpoint directory.
    latest_checkpoint = get_latest_checkpoint(os.path.join(PROJECT_DIR, checkpoint_dir))
    logfile_dir = os.path.join(PROJECT_DIR, checkpoint_dir, latest_checkpoint)
    logfile_path = os.path.join(logfile_dir, "trainer_state.json")
    with open(logfile_path, encoding="utf-8") as logfile:
        logs = json.load(logfile)

    final_loss = loss_function(logs["log_history"])

    # NOTE(review): unlike the checkpoint path, the output directory is not
    # joined with PROJECT_DIR, so it resolves against the current working
    # directory when relative. Confirm this is intended.
    output_dir = config["SENTENCE_COMPRESSION"]["OUTPUT"]["RESULT"]
    os.makedirs(output_dir, exist_ok=True)
    plot_loss(final_loss, output_dir)