Diagnosis of Acute Coronary Syndrome using ECG waveforms: A machine learning framework and benchmark dataset
Code repository: https://github.com/alexmschubert/ACS-BenchWork
Project description
In this project, we benchmarked various machine learning (ML) approaches for detecting acute coronary syndrome (ACS), commonly known as a 'heart attack', from 12-lead ECG waveforms. Our findings show that ML models can identify large groups of high-risk patients who show none of the classical ECG features (e.g., ST-elevation or ST-depression) that cardiologists typically associate with ACS. We are releasing the weights of our best-performing models to facilitate further development in data-driven ACS detection. The training dataset, ACS-BenchWork, is available on Nightingale Open Science, and we invite researchers to build on these resources to improve the accuracy and utility of ECG-based ACS screening tools.
Models
We are releasing the weights for the three best-performing models:
- S4-ECG: A structured state space model well suited to ECG waveforms, capturing both local and long-range temporal features for robust ACS detection. It builds on work by Strodthoff et al.
- ResNet-18: A 1D convolutional residual network adapted for ECG data, leveraging skip connections to learn complex patterns without vanishing gradients.
- HuBERT-ECG: A transformer-based architecture originally designed for speech recognition, repurposed and fine-tuned on ECG signals for ACS prediction. It is based on work by Coppola et al.
Usage
Below are sample snippets for loading each model with its pretrained weights. Once loaded, you can run inference by passing your pre-processed ECG data to the model’s forward method. Please refer to the project’s GitHub repository for end-to-end examples demonstrating how to use these models.
S4-ECG
import torch
import lightning.pytorch as pl
from src.lightning import S4Model

def load_from_checkpoint(pl_model, checkpoint_path):
    """Load weights from a Lightning checkpoint in a way that is compatible with S4."""
    lightning_state_dict = torch.load(checkpoint_path)
    state_dict = lightning_state_dict["state_dict"]
    # Copy every parameter and buffer from the checkpoint into the model by name.
    for name, param in pl_model.named_parameters():
        param.data = state_dict[name].data
    for name, buffer in pl_model.named_buffers():
        buffer.data = state_dict[name].data

checkpoint_path = "path/to/your/benchmark_acs_state_v0/dmaxlwcg/checkpoints/epoch=49-step=1100.ckpt"

# Build the model with the same hyperparameters, then copy the weights in place.
model = S4Model(init_lr=1e-4,
                d_input=3,
                d_output=1)
load_from_checkpoint(model, checkpoint_path)
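The helper above indexes the checkpoint's state dict by each parameter and buffer name, so a name missing from the checkpoint raises a KeyError. As an optional sanity check, the names could be compared before copying. The sketch below is illustrative only: `diff_state_dict_keys` is a hypothetical helper that is not part of the repository, and the key names are toy values.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Compare two collections of parameter names.

    Returns (missing, unexpected): names the model expects but the
    checkpoint lacks, and names the checkpoint holds but the model does not.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Toy names for illustration. With a real model, the first argument could be
# the names yielded by model.named_parameters() and model.named_buffers(),
# and the second could be state_dict.keys().
model_keys = ["encoder.weight", "encoder.bias", "decoder.weight"]
checkpoint_keys = ["encoder.weight", "encoder.bias", "head.weight"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print(missing)     # ['decoder.weight']
print(unexpected)  # ['head.weight']
```

An empty `missing` list indicates that every name the copy loop will look up is present in the checkpoint.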
ResNet-18
import torch
import lightning.pytorch as pl
from src.lightning import ResNet18_1D
checkpoint_path = "path/to/your/benchmark_acs_resnet18_1d_final_vf/2vud5fft/checkpoints/epoch=37-step=418.ckpt"
model = ResNet18_1D.load_from_checkpoint(checkpoint_path)
HuBERT-ECG
import torch
from hubert_ecg import HuBERTECG, HuBERTECGConfig
from hubert_ecg_classification import HuBERTForECGClassification
path = "path/to/your/hubert_3_iteration_300_finetuned_simdmsnv.pt"
checkpoint = torch.load(path, map_location='cpu')
config = checkpoint['model_config']
hubert_ecg = HuBERTECG(config)
hubert_ecg = HuBERTForECGClassification(hubert_ecg)
hubert_ecg.load_state_dict(checkpoint['model_state_dict'])
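Once a model has been loaded, inference might look like the minimal sketch below. It rests on several assumptions that are not confirmed by this README: the input layout (batch, leads, samples), the use of 3 leads (mirroring `d_input=3` in the S4-ECG snippet), the arbitrary length of 1000 samples, and the interpretation of the single output as a logit that a sigmoid turns into a probability. `StandInClassifier` and `predict_proba` are hypothetical names used only so the sketch runs on its own; replace the stand-in with one of the loaded models and consult the GitHub repository for the actual pre-processing and input shapes.

```python
import torch
from torch import nn


class StandInClassifier(nn.Module):
    """Hypothetical placeholder for a loaded ACS model (not a released model)."""

    def __init__(self, n_leads: int = 3):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(n_leads, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, leads, samples) -> one value per ECG, shape (batch, 1)
        return self.head(self.pool(x).squeeze(-1))


def predict_proba(model: nn.Module, ecg: torch.Tensor) -> torch.Tensor:
    """Run a forward pass without gradients and map outputs to [0, 1].

    Assumes the model returns one logit per ECG.
    """
    model.eval()  # disable dropout and use running batch-norm statistics
    with torch.no_grad():
        logits = model(ecg)
    return torch.sigmoid(logits).squeeze(-1)


torch.manual_seed(0)
model = StandInClassifier(n_leads=3)
ecg = torch.randn(4, 3, 1000)  # random data standing in for pre-processed ECGs
probs = predict_proba(model, ecg)
print(probs.shape)  # torch.Size([4])
```

Calling `model.eval()` and wrapping the forward pass in `torch.no_grad()` is the usual PyTorch pattern for inference, since it avoids tracking gradients and keeps layers such as dropout in evaluation mode.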
Citation
Paper currently under review