Diagnosis of Acute Coronary Syndrome using ECG waveforms: A machine learning framework and benchmark dataset

Code repository: https://github.com/alexmschubert/ACS-BenchWork

Project description

In this project, we benchmarked several machine learning (ML) approaches for detecting acute coronary syndrome (ACS), commonly known as a 'heart attack', from 12-lead ECG waveforms. Our findings show that ML models can identify large groups of high-risk patients whose ECGs show none of the classical features (e.g., ST-elevation or ST-depression) that cardiologists typically associate with ACS. We are releasing the weights of our best-performing models to facilitate further development of data-driven ACS detection. The training dataset, ACS-BenchWork, is available on Nightingale Open Science, and we invite researchers to build on these resources to improve the accuracy and utility of ECG-based ACS screening tools.

Models

We are releasing the weights for the three best-performing models:

  • S4-ECG: A Structured State Space model well-suited to ECG waveforms, capturing both local and long-range temporal features for robust ACS detection. Builds on work by Strodthoff et al.
  • ResNet-18: A 1D convolutional residual network adapted for ECG data, leveraging skip connections to learn complex patterns without vanishing gradients.
  • HuBERT-ECG: A transformer-based architecture originally designed for speech recognition, repurposed and fine-tuned on ECG signals for ACS prediction. Based on work by Coppola et al.

Usage

Below are sample snippets for loading each model with its pretrained weights. Once loaded, you can run inference by passing your pre-processed ECG data to the model’s forward method. Please refer to the project’s GitHub repository for end-to-end examples demonstrating how to use these models.

S4-ECG

import torch
import lightning.pytorch as pl
from src.lightning import S4Model


def load_from_checkpoint(pl_model, checkpoint_path):
    """Copy checkpoint weights into an S4 model, matching entries by name."""
    # map_location="cpu" lets the checkpoint load on machines without a GPU.
    lightning_state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = lightning_state_dict["state_dict"]

    # Overwrite every parameter and buffer with the stored tensor of the same name.
    for name, param in pl_model.named_parameters():
        param.data = state_dict[name].data
    for name, buffer in pl_model.named_buffers():
        buffer.data = state_dict[name].data


checkpoint_path = "path/to/your/benchmark_acs_state_v0/dmaxlwcg/checkpoints/epoch=49-step=1100.ckpt"

model = S4Model(init_lr=1e-4,
                d_input=3,
                d_output=1)

load_from_checkpoint(model, checkpoint_path)
model.eval()  # switch to inference mode
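
The S4 loader above looks up each parameter and buffer by name, so a name that is absent from the checkpoint surfaces as a KeyError partway through loading. As a minimal sketch, a small pre-check can list mismatches before any weights are copied. The helper name diff_state_keys and the example names are hypothetical and not part of the repository; with a real model the names would come from pl_model.named_parameters(), pl_model.named_buffers(), and state_dict.keys().

```python
def diff_state_keys(model_names, checkpoint_names):
    """Return (missing, unexpected) name lists, each sorted.

    missing:    names the model expects but the checkpoint lacks.
    unexpected: names in the checkpoint that the model does not define.
    """
    model_set = set(model_names)
    ckpt_set = set(checkpoint_names)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


# Stand-in names for illustration only (hypothetical, not real S4 layer names).
model_names = ["encoder.weight", "encoder.bias", "decoder.weight"]
checkpoint_names = ["encoder.weight", "encoder.bias", "head.weight"]

missing, unexpected = diff_state_keys(model_names, checkpoint_names)
print("missing from checkpoint:", missing)
print("unexpected in checkpoint:", unexpected)
```

If both lists are empty, the name-by-name copy should find every entry it asks for.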

ResNet-18

import torch
import lightning.pytorch as pl
from src.lightning import ResNet18_1D

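checkpoint_path = "path/to/your/benchmark_acs_resnet18_1d_final_vf/2vud5fft/checkpoints/epoch=37-step=418.ckpt"
# Lightning restores the hyperparameters and weights saved in the checkpoint.
model = ResNet18_1D.load_from_checkpoint(checkpoint_path, map_location="cpu")
model.eval()  # switch to inference mode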

HuBERT-ECG

import torch
from hubert_ecg import HuBERTECG, HuBERTECGConfig
from hubert_ecg_classification import HuBERTForECGClassification

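path = "path/to/your/hubert_3_iteration_300_finetuned_simdmsnv.pt"
# The checkpoint appears to store a config object alongside the weights, so it needs
# full unpickling (weights_only=False); only load checkpoint files you trust.
checkpoint = torch.load(path, map_location='cpu', weights_only=False)
config = checkpoint['model_config']
# Build the backbone from the stored config, then wrap it with the classification head.
hubert_ecg = HuBERTECG(config)
hubert_ecg = HuBERTForECGClassification(hubert_ecg)
hubert_ecg.load_state_dict(checkpoint['model_state_dict'])
hubert_ecg.eval()  # switch to inference mode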
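
Inference

For any of the three models, a forward pass on pre-processed ECG data might look like the sketch below. It makes several assumptions that should be checked against the GitHub repository: that the model accepts a tensor shaped (batch, leads, samples), that it returns one raw logit per ECG, and that a sigmoid maps that logit to a probability. StandInModel and predict_proba are hypothetical names used only so the sketch runs without a checkpoint; replace StandInModel with a loaded model, and adapt the output handling if the model returns a tuple or an output object instead of a tensor.

```python
import torch
from torch import nn


class StandInModel(nn.Module):
    """Hypothetical placeholder so the sketch runs without a checkpoint."""

    def __init__(self, n_leads=12):
        super().__init__()
        self.conv = nn.Conv1d(n_leads, 8, kernel_size=7, padding=3)
        self.head = nn.Linear(8, 1)

    def forward(self, x):
        # x: (batch, leads, samples) -> one logit per ECG, shape (batch, 1)
        features = torch.relu(self.conv(x)).mean(dim=-1)
        return self.head(features)


def predict_proba(model, ecg_batch):
    """Run a gradient-free forward pass and map logits into [0, 1]."""
    model.eval()  # disable dropout and use running batch-norm statistics
    with torch.no_grad():
        logits = model(ecg_batch)
    return torch.sigmoid(logits).squeeze(-1)


# Random data standing in for 4 pre-processed 12-lead ECGs of 1000 samples each.
ecg_batch = torch.randn(4, 12, 1000)
probs = predict_proba(StandInModel(), ecg_batch)
print(probs.shape)
```

The expected input shape differs between models (the S4 example above is constructed with d_input=3, for instance), so the batch dimensions here are illustrative only.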

Citation

Paper currently under review
