|
import os |
|
import json |
|
import pandas as pd |
|
import yaml |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
from inference import get_latest_checkpoint |
|
|
|
|
|
def process_loss(loss, final_loss):
    """Append one log record's epoch and tracked metrics to the accumulator.

    Args:
        loss: Mapping for a single log record. It must contain ``"epoch"``;
            the metric keys are optional.
        final_loss: Mapping from series name to list. It is mutated in place
            and must contain an ``"epoch"`` list.

    The epoch is converted with ``int`` before it is stored, so a fractional
    value is truncated. Nothing is returned.
    """
    epoch_number = int(loss["epoch"])
    final_loss["epoch"].append(epoch_number)

    tracked_metrics = ("loss", "eval_loss", "eval_rouge1", "eval_rouge2")
    for metric in tracked_metrics:
        try:
            reading = loss[metric]
            final_loss[metric].append(reading)
        except KeyError:
            # The metric is missing from the record (or the accumulator has
            # no series for it); leave it untouched and move on.
            continue
|
|
|
def loss_function(losses):
    """Collect the metrics logged at whole-numbered epochs into per-metric lists.

    Args:
        losses: Iterable of dict-like log records. A record is used only when
            its ``"epoch"`` value is a whole number (``float(epoch) % 1 == 0``).

    Returns:
        dict with the keys ``"epoch"``, ``"loss"``, ``"eval_loss"``,
        ``"eval_rouge1"`` and ``"eval_rouge2"``, each mapped to a list.
        ``"epoch"`` holds every accepted epoch once, in the order it was first
        seen; the other lists hold the values found for that metric, in
        record order.
    """
    final_loss = {
        "epoch": [],
        "loss": [],
        "eval_loss": [],
        "eval_rouge1": [],
        "eval_rouge2": [],
    }
    for loss_steps in losses:
        epoch = loss_steps.get("epoch")
        if epoch is None:
            # A record without an epoch cannot be placed on the epoch axis.
            # Previously it defaulted to 0, passed the whole-epoch filter and
            # then failed with KeyError inside process_loss.
            continue
        if float(epoch) % 1 == 0:
            process_loss(loss_steps, final_loss)
    # Several records can share one epoch, so the epoch list is de-duplicated.
    # dict.fromkeys keeps first-seen order, whereas set() iteration order is
    # not guaranteed and could misalign the epochs with the metric lists.
    final_loss["epoch"] = list(dict.fromkeys(final_loss["epoch"]))
    return final_loss
|
|
|
def plot_loss(data, output_dir):
    """Plot every metric series against epoch and save the chart as a PNG.

    Args:
        data: Column-oriented data accepted by ``pd.DataFrame``; it must
            provide an ``epoch`` column. All other columns are treated as
            metrics.
        output_dir: Directory in which ``metrics_vs_epoch.png`` is written.
    """
    # Reshape wide -> long: one row per (epoch, metric) pair, so that seaborn
    # can draw one line per metric via ``hue``.
    long_form = pd.DataFrame(data).melt(
        id_vars=['epoch'],
        var_name='metric',
        value_name='value',
    )

    plt.figure(figsize=(10, 6))
    sns.lineplot(data=long_form, x='epoch', y='value', hue='metric', marker='o')

    # Titles, axis labels and legend.
    plt.title('Metrics vs Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.legend(title='Metric')

    image_path = os.path.join(output_dir, 'metrics_vs_epoch.png')
    plt.savefig(image_path)
|
|
|
|
|
if __name__ == "__main__":
    # Load the configuration; the context manager closes the file handle
    # instead of leaving it to the garbage collector.
    with open("config.yaml", "r") as config_file:
        config = yaml.safe_load(config_file)

    # SECURITY NOTE(review): PROJECT_DIR is evaluated as a Python expression,
    # so anyone who can edit config.yaml can execute arbitrary code here.
    # Kept as-is because existing configs presumably store an expression;
    # confirm, and replace with a plain path value if possible.
    PROJECT_DIR = eval(config["SENTENCE_COMPRESSION"]["PROJECT_DIR"])

    # Locate trainer_state.json inside the most recent checkpoint directory.
    checkpoint_dir = config["SENTENCE_COMPRESSION"]["INFERENCE"]["MODEL_PATH"]
    latest_checkpoint = get_latest_checkpoint(os.path.join(PROJECT_DIR, checkpoint_dir))
    logfile_dir = os.path.join(PROJECT_DIR, checkpoint_dir, latest_checkpoint)
    logfile_path = os.path.join(logfile_dir, "trainer_state.json")

    with open(logfile_path) as logfile:
        logs = json.load(logfile)

    # Reduce the log history to per-epoch metric series and plot them.
    final_loss = loss_function(logs["log_history"])
    output_dir = config["SENTENCE_COMPRESSION"]["OUTPUT"]["RESULT"]
    os.makedirs(output_dir, exist_ok=True)
    plot_loss(final_loss, output_dir)