---
tags:
- astronomy
- multimodal
- classification
datasets:
- MeriDK/AstroM3Processed
- MeriDK/AstroM3Dataset
---

# AstroM3-CLIP-photo

AstroM³ is a self-supervised multimodal model for astronomy that integrates time-series photometry, spectra, and metadata into a unified embedding space for classification and other downstream tasks. AstroM³ is trained on [AstroM3Processed](https://huggingface.co/datasets/MeriDK/AstroM3Processed). AstroM3-CLIP-photo is the base AstroM3-CLIP model fine-tuned for photometry-only classification.

For more details on the AstroM³ architecture, training, and results, please refer to the [paper](https://arxiv.org/abs/2411.08842).

<p align="center">
<img src="astroclip-architecture.png" width="70%">
<br />
<span>
Figure 1: Overview of the multimodal CLIP framework adapted for astronomy, incorporating three data modalities: photometric time-series, spectra, and metadata.
Each modality is processed by a dedicated encoder to create embeddings, which are then mapped into a shared embedding space through projection heads.
Pairwise similarity matrices align the embeddings across modalities, and a symmetric cross-entropy loss, computed over these matrices, optimizes the model.
The total loss, derived from all pairwise losses, guides the model’s trimodal learning.
</span>
</p>
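
The objective described in Figure 1 can be sketched in a few lines of NumPy. This is a minimal illustration, not the AstroM3 training code: it assumes L2-normalized embeddings, a fixed temperature of 0.07 (CLIP-style models often learn this value instead), and a total loss equal to the unweighted sum of the three pairwise losses. The function names `clip_pair_loss` and `trimodal_loss` are hypothetical.

```python
import numpy as np


def _log_softmax(x, axis):
    # Numerically stable log-softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def clip_pair_loss(a, b, temperature=0.07):
    """Symmetric cross-entropy between two (N, D) batches of embeddings.

    Row i of `a` and row i of `b` describe the same object, so the matching
    pairs lie on the diagonal of the similarity matrix.
    """
    # L2-normalize so that dot products are cosine similarities
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    logits = a @ b.T / temperature  # (N, N) pairwise similarity matrix
    idx = np.arange(len(a))
    loss_ab = -_log_softmax(logits, axis=1)[idx, idx].mean()  # a -> b (over rows)
    loss_ba = -_log_softmax(logits, axis=0)[idx, idx].mean()  # b -> a (over columns)
    return (loss_ab + loss_ba) / 2


def trimodal_loss(p, s, m, temperature=0.07):
    # Assumption: total loss = unweighted sum of the three pairwise losses
    return (clip_pair_loss(p, s, temperature)
            + clip_pair_loss(s, m, temperature)
            + clip_pair_loss(m, p, temperature))


# Usage: a batch of 8 objects with 16-dimensional embeddings per modality
rng = np.random.default_rng(0)
p, s, m = rng.normal(size=(3, 8, 16))
print(trimodal_loss(p, s, m))
```

Averaging the row-wise and column-wise cross-entropies is what makes each pairwise term symmetric: swapping the two modalities gives the same loss.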

To perform inference with AstroM³, clone the AstroM3 code from our [GitHub repo](https://github.com/MeriDK/AstroM3):

```sh
git clone https://github.com/MeriDK/AstroM3.git
cd AstroM3
```

Create a virtual environment (tested with Python 3.10.14), then install the required dependencies:

```sh
uv venv venv --python 3.10.14
source venv/bin/activate
uv pip install -r requirements.txt
```

A simple example to get started:

1. Data Loading & Preprocessing

```python
# Run from the root of the cloned AstroM3 repository so that `src` is importable
from datasets import load_dataset
from src.data import process_photometry

# Load the test split of the dataset (seed-42 configuration)
test_dataset = load_dataset('MeriDK/AstroM3Processed', name='full_42', split='test')

# Process photometry to have a fixed sequence length of 200 (center-cropped)
test_dataset = test_dataset.map(process_photometry, batched=True, fn_kwargs={'seq_len': 200, 'how': 'center'})

# Return PyTorch tensors when indexing the dataset
test_dataset = test_dataset.with_format('torch')
```

2. Model Loading & Embedding Extraction

```python
import torch
from src.model import AstroM3

# Load the base AstroM3-CLIP model
model = AstroM3.from_pretrained('MeriDK/AstroM3-CLIP')

# Retrieve the first sample (batch size = 1)
sample = test_dataset[0:1]
photometry = sample['photometry']
photometry_mask = sample['photometry_mask']
spectra = sample['spectra']
metadata = sample['metadata']

# Example 1: Generate embeddings when all modalities are present
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)
multimodal_emb = (p_emb + s_emb + m_emb) / 3
print('Multimodal Embedding (All Modalities):', multimodal_emb)

# Example 2: Generate embeddings when the spectra modality is missing
dummy_spectra = torch.zeros_like(spectra)  # Placeholder input for the missing spectra
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, dummy_spectra, metadata)

# Average only the embeddings of the available modalities (s_emb is ignored)
multimodal_emb_missing = (p_emb + m_emb) / 2
print('Multimodal Embedding (Spectra Missing):', multimodal_emb_missing)
```

3. Classification Examples

```python
from src.model import AstroM3, Informer, GalSpecNet, MetaModel

# Reuses `photometry`, `photometry_mask`, `spectra`, and `metadata` from step 2

# Photometry classification (this model)
photo_model = Informer.from_pretrained('MeriDK/AstroM3-CLIP-photo')
prediction = photo_model(photometry, photometry_mask).argmax(dim=1).item()
print('Photometry Classification:', test_dataset.features['label'].int2str(prediction))

# Spectra classification
spectra_model = GalSpecNet.from_pretrained('MeriDK/AstroM3-CLIP-spectra')
prediction = spectra_model(spectra).argmax(dim=1).item()
print('Spectra Classification:', test_dataset.features['label'].int2str(prediction))

# Metadata classification
meta_model = MetaModel.from_pretrained('MeriDK/AstroM3-CLIP-meta')
prediction = meta_model(metadata).argmax(dim=1).item()
print('Metadata Classification:', test_dataset.features['label'].int2str(prediction))

# Multimodal classification
all_model = AstroM3.from_pretrained('MeriDK/AstroM3-CLIP-all')
prediction = all_model(photometry, photometry_mask, spectra, metadata).argmax(dim=1).item()
print('Multimodal Classification:', test_dataset.features['label'].int2str(prediction))
```
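
Step 1 fixes every light curve to 200 time steps using a center crop. The function below is a hypothetical NumPy illustration of that idea and is not the implementation of `process_photometry`. It assumes each light curve is a `(T, F)` array, keeps the middle `seq_len` steps of longer sequences, zero-pads shorter ones at the end, and returns a boolean mask that is `True` for real observations. The padding side and mask convention that AstroM3 actually uses may differ.

```python
import numpy as np


def center_crop_or_pad(seq, seq_len=200):
    """Hypothetical sketch: fix a (T, F) light curve to (seq_len, F) plus a validity mask."""
    seq = np.asarray(seq, dtype=float)
    n = seq.shape[0]
    mask = np.zeros(seq_len, dtype=bool)
    if n >= seq_len:
        # Keep the central window of length seq_len
        start = (n - seq_len) // 2
        out = seq[start:start + seq_len]
        mask[:] = True
    else:
        # Zero-pad at the end; only the first n positions hold real data
        pad = np.zeros((seq_len - n,) + seq.shape[1:])
        out = np.concatenate([seq, pad], axis=0)
        mask[:n] = True
    return out, mask


# Usage: a 10-step, single-feature light curve cropped to 4 steps
curve = np.arange(10).reshape(10, 1)
cropped, valid = center_crop_or_pad(curve, seq_len=4)
print(cropped[:, 0], valid)  # rows 3..6 are kept; every mask entry is True
```

Center cropping keeps the same number of observations on both sides of the window, so a long light curve loses its earliest and latest points roughly equally.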
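
Step 2 averages the embeddings of whichever modalities are available and ignores the embedding computed from the zero placeholder. A small helper can generalize that pattern to any subset of modalities. The name `fuse_embeddings` and the modality keys are illustrative and are not part of the AstroM3 library. Because the helper only uses `+` and `/`, it works with PyTorch tensors, NumPy arrays, or plain floats.

```python
def fuse_embeddings(embeddings, available):
    """Average the embeddings of the modalities listed in `available`.

    `embeddings` maps a modality name to its embedding (tensor, array, or float);
    entries for missing modalities are simply not averaged in.
    """
    present = [embeddings[name] for name in available]
    if not present:
        raise ValueError("at least one modality must be available")
    return sum(present) / len(present)
```

For example, `fuse_embeddings({'photometry': p_emb, 'spectra': s_emb, 'metadata': m_emb}, available=['photometry', 'metadata'])` computes the same value as `(p_emb + m_emb) / 2` in Example 2.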

## The AstroM³ Family

| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP](https://huggingface.co/MeriDK/AstroM3-CLIP) | The base model pre-trained using the trimodal CLIP approach. |
| [AstroM3-CLIP-meta](https://huggingface.co/MeriDK/AstroM3-CLIP-meta) | Fine-tuned for metadata-only classification. |
| [AstroM3-CLIP-spectra](https://huggingface.co/MeriDK/AstroM3-CLIP-spectra) | Fine-tuned for spectra-only classification. |
| [AstroM3-CLIP-photo](https://huggingface.co/MeriDK/AstroM3-CLIP-photo) | Fine-tuned for photometry-only classification. |
| [AstroM3-CLIP-all](https://huggingface.co/MeriDK/AstroM3-CLIP-all) | Fine-tuned for multimodal classification. |

## AstroM3-CLIP Variants

These variants of the base AstroM3-CLIP model were trained with different random seeds (42, 0, 66, 12, 123). When using a variant, load the dataset with the corresponding seed for consistency.

| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP-42](https://huggingface.co/MeriDK/AstroM3-CLIP-42) | The base model pre-trained with random seed 42 (identical to AstroM3-CLIP). |
| [AstroM3-CLIP-0](https://huggingface.co/MeriDK/AstroM3-CLIP-0) | AstroM3-CLIP pre-trained with random seed 0 (use dataset with seed 0). |
| [AstroM3-CLIP-66](https://huggingface.co/MeriDK/AstroM3-CLIP-66) | AstroM3-CLIP pre-trained with random seed 66 (use dataset with seed 66). |
| [AstroM3-CLIP-12](https://huggingface.co/MeriDK/AstroM3-CLIP-12) | AstroM3-CLIP pre-trained with random seed 12 (use dataset with seed 12). |
| [AstroM3-CLIP-123](https://huggingface.co/MeriDK/AstroM3-CLIP-123) | AstroM3-CLIP pre-trained with random seed 123 (use dataset with seed 123). |
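
To keep the model and the data in sync, you can derive both the model repository name and the dataset configuration from the same seed. The helper below is hypothetical. It assumes the dataset configurations follow the `full_<seed>` pattern of the `full_42` configuration used in the getting-started example; check the AstroM3Processed dataset page for the exact configuration names.

```python
# Hypothetical helper: pairs each AstroM3-CLIP seed variant with a dataset config.
# Assumption: configs are named 'full_<seed>', like the 'full_42' config used above.
SEEDS = (42, 0, 66, 12, 123)


def variant_for_seed(seed):
    """Return (model_repo_id, dataset_config_name) for a given training seed."""
    if seed not in SEEDS:
        raise ValueError(f"no AstroM3-CLIP variant for seed {seed}; expected one of {SEEDS}")
    return f"MeriDK/AstroM3-CLIP-{seed}", f"full_{seed}"


if __name__ == "__main__":
    repo_id, config = variant_for_seed(0)
    print(repo_id, config)  # MeriDK/AstroM3-CLIP-0 full_0
```

You can then pass the returned names to `AstroM3.from_pretrained(repo_id)` and `load_dataset('MeriDK/AstroM3Processed', name=config, split='test')`.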