---
tags:
- astronomy
- multimodal
- classification
datasets:
- MeriDK/AstroM3Processed
- MeriDK/AstroM3Dataset
---

# AstroM3-CLIP-12

AstroM³ is a self-supervised multimodal model for astronomy that integrates time-series photometry, spectra, and metadata into a unified embedding space
for classification and other downstream tasks. AstroM³ is trained on [AstroM3Processed](https://huggingface.co/datasets/MeriDK/AstroM3Processed).
For more details on the AstroM³ architecture, training, and results, please refer to the [paper](https://arxiv.org/abs/2411.08842).

<p align="center">
  <img src="astroclip-architecture.png" width="70%">
  <br />
  <span>
    Figure 1: Overview of the multimodal CLIP framework adapted for astronomy, incorporating three data modalities: photometric time-series, spectra, and metadata. 
    Each modality is processed by a dedicated encoder to create embeddings, which are then mapped into a shared embedding space through projection heads. 
    Pairwise similarity matrices align the embeddings across modalities, and a symmetric cross-entropy loss, computed over these matrices, optimizes the model. 
    The total loss, derived from all pairwise losses, guides the model’s trimodal learning.
  </span>
</p>
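The pairwise symmetric cross-entropy described in the figure can be sketched in a few lines. This is a minimal NumPy illustration of the loss idea only, not the repository's implementation; the batch size, embedding dimension, and temperature value are illustrative assumptions.

```python
import numpy as np

def symmetric_ce(a, b, temperature=0.07):
    """Symmetric cross-entropy over the pairwise similarity matrix of two
    batches of L2-normalized embeddings (matched pairs sit on the diagonal)."""
    logits = (a @ b.T) / temperature
    n = len(a)
    def ce(l):
        l = l - l.max(axis=1, keepdims=True)  # numerical stability
        logp = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -logp[np.arange(n), np.arange(n)].mean()  # diagonal = matched pairs
    return 0.5 * (ce(logits) + ce(logits.T))

rng = np.random.default_rng(0)
def embed(n=8, d=16):
    x = rng.normal(size=(n, d))
    return x / np.linalg.norm(x, axis=1, keepdims=True)

# Toy photometry / spectra / metadata embeddings for one batch
p_emb, s_emb, m_emb = embed(), embed(), embed()

# Total loss = sum of the three pairwise symmetric losses
total = symmetric_ce(p_emb, s_emb) + symmetric_ce(p_emb, m_emb) + symmetric_ce(s_emb, m_emb)
print(f'trimodal loss: {total:.4f}')
```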

To perform inference with AstroM³, clone the AstroM3 repository from [GitHub](https://github.com/MeriDK/AstroM3):
```sh
git clone https://github.com/MeriDK/AstroM3.git
cd AstroM3
```
Create a virtual environment (tested with Python 3.10.14), then install the required dependencies:
```sh
uv venv venv --python 3.10.14
source venv/bin/activate
uv pip install -r requirements.txt
```

A simple example to get started:
1. Data Loading & Preprocessing
```python
from datasets import load_dataset
from src.data import process_photometry

# Load the test dataset
test_dataset = load_dataset('MeriDK/AstroM3Processed', name='full_42', split='test')

# Process photometry to have a fixed sequence length of 200 (center-cropped)
test_dataset = test_dataset.map(process_photometry, batched=True, fn_kwargs={'seq_len': 200, 'how': 'center'})
test_dataset = test_dataset.with_format('torch')
```
2. Model Loading & Embedding Extraction
```python
import torch
from src.model import AstroM3

# Load the base AstroM3-CLIP model
model = AstroM3.from_pretrained('MeriDK/AstroM3-CLIP')

# Retrieve the first sample (batch size = 1)
sample = test_dataset[0:1]
photometry = sample['photometry']
photometry_mask = sample['photometry_mask']
spectra = sample['spectra']
metadata = sample['metadata']

# Example 1: Generate embeddings when all modalities are present
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)
multimodal_emb = (p_emb + s_emb + m_emb) / 3
print('Multimodal Embedding (All Modalities):', multimodal_emb)

# Example 2: Generate embeddings when the spectra modality is missing
dummy_spectra = torch.zeros_like(spectra)  # Dummy tensor for missing spectra
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, dummy_spectra, metadata)
multimodal_emb_missing = (p_emb + m_emb) / 2
print('Multimodal Embedding (Spectra Missing):', multimodal_emb_missing)
```
3. Classification Examples
```python
from src.model import AstroM3, Informer, GalSpecNet, MetaModel

# Photometry classification
photo_model = Informer.from_pretrained('MeriDK/AstroM3-CLIP-photo')
prediction = photo_model(photometry, photometry_mask).argmax(dim=1).item()
print('Photometry Classification:', test_dataset.features['label'].int2str(prediction))

# Spectra classification
spectra_model = GalSpecNet.from_pretrained('MeriDK/AstroM3-CLIP-spectra')
prediction = spectra_model(spectra).argmax(dim=1).item()
print('Spectra Classification:', test_dataset.features['label'].int2str(prediction))

# Metadata classification
meta_model = MetaModel.from_pretrained('MeriDK/AstroM3-CLIP-meta')
prediction = meta_model(metadata).argmax(dim=1).item()
print('Metadata Classification:', test_dataset.features['label'].int2str(prediction))

# Multimodal classification
all_model = AstroM3.from_pretrained('MeriDK/AstroM3-CLIP-all')
prediction = all_model(photometry, photometry_mask, spectra, metadata).argmax(dim=1).item()
print('Multimodal Classification:', test_dataset.features['label'].int2str(prediction))
```
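Because all modalities live in one shared embedding space, the per-modality embeddings from step 2 can also be compared directly, e.g. for cross-modal retrieval. The sketch below uses NumPy with toy data and illustrative shapes; it is not part of the repository's API.

```python
import numpy as np

def retrieve(query_emb, gallery_emb, k=3):
    """Return indices of the k most cosine-similar gallery embeddings per query."""
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    g = gallery_emb / np.linalg.norm(gallery_emb, axis=1, keepdims=True)
    sims = q @ g.T                       # cosine similarity matrix
    return np.argsort(-sims, axis=1)[:, :k]

# Toy check: a gallery of spectra embeddings where entry 5 matches the query
rng = np.random.default_rng(0)
gallery = rng.normal(size=(10, 512))
query = gallery[5:6].copy()
top = retrieve(query, gallery, k=1)      # the matching entry should rank first
print(top)
```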

## The AstroM³ Family

| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP](https://huggingface.co/MeriDK/AstroM3-CLIP) | The base model pre-trained using the trimodal CLIP approach. |
| [AstroM3-CLIP-meta](https://huggingface.co/MeriDK/AstroM3-CLIP-meta) | Fine-tuned for metadata-only classification. |
| [AstroM3-CLIP-spectra](https://huggingface.co/MeriDK/AstroM3-CLIP-spectra) | Fine-tuned for spectra-only classification. |
| [AstroM3-CLIP-photo](https://huggingface.co/MeriDK/AstroM3-CLIP-photo) | Fine-tuned for photometry-only classification. |
| [AstroM3-CLIP-all](https://huggingface.co/MeriDK/AstroM3-CLIP-all) | Fine-tuned for multimodal classification. |

## AstroM3-CLIP Variants
These variants of the base AstroM3-CLIP model are pre-trained with different random seeds (42, 0, 66, 12, 123).
When working with a variant, load the dataset with the matching seed so that the data splits line up.

| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP-42](https://huggingface.co/MeriDK/AstroM3-CLIP-42) | The base model pre-trained with random seed 42 (identical to AstroM3-CLIP). |
| [AstroM3-CLIP-0](https://huggingface.co/MeriDK/AstroM3-CLIP-0) | AstroM3-CLIP pre-trained with random seed 0 (use dataset with seed 0). |
| [AstroM3-CLIP-66](https://huggingface.co/MeriDK/AstroM3-CLIP-66) | AstroM3-CLIP pre-trained with random seed 66 (use dataset with seed 66). |
| [AstroM3-CLIP-12](https://huggingface.co/MeriDK/AstroM3-CLIP-12) | AstroM3-CLIP pre-trained with random seed 12 (use dataset with seed 12). |
| [AstroM3-CLIP-123](https://huggingface.co/MeriDK/AstroM3-CLIP-123) | AstroM3-CLIP pre-trained with random seed 123 (use dataset with seed 123). |
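To keep model and dataset seeds in sync, it is convenient to derive both names from the seed. The helper below is hypothetical (not part of the repo) and assumes the `full_<seed>` dataset config naming shown above for seed 42 extends to the other seeds:

```python
def variant_names(seed):
    """Return the (dataset config, model repo) pair for a pretraining seed.
    Assumes the 'full_<seed>' config pattern seen for seed 42 holds for all seeds."""
    return f'full_{seed}', f'MeriDK/AstroM3-CLIP-{seed}'

# Usage, e.g. for seed 0:
config, repo = variant_names(0)
# dataset = load_dataset('MeriDK/AstroM3Processed', name=config, split='test')
# model = AstroM3.from_pretrained(repo)
```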