{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PrithviWxC Rollout Inference\n", "If you haven't already, take a look at the example for the PrithviWxC core\n", "model, as we will skip over the points covered there.\n", "\n", "Here we introduce the PrithviWxC model that was trained further for\n", "autoregressive rollout, a common strategy to increase accuracy and stability of\n", "models when applied to forecasting-type tasks. " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import random\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "from huggingface_hub import hf_hub_download, snapshot_download\n", "\n", "# Set backend etc.\n", "torch.jit.enable_onednn_fusion(True)\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.benchmark = True\n", " torch.backends.cudnn.deterministic = True\n", "\n", "# Set seeds\n", "random.seed(42)\n", "if torch.cuda.is_available():\n", " torch.cuda.manual_seed(42)\n", "torch.manual_seed(42)\n", "np.random.seed(42)\n", "\n", "# Set device\n", "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "else:\n", " device = torch.device(\"cpu\")\n", "\n", "# Set variables\n", "surface_vars = [\n", " \"EFLUX\",\n", " \"GWETROOT\",\n", " \"HFLUX\",\n", " \"LAI\",\n", " \"LWGAB\",\n", " \"LWGEM\",\n", " \"LWTUP\",\n", " \"PS\",\n", " \"QV2M\",\n", " \"SLP\",\n", " \"SWGNT\",\n", " \"SWTNT\",\n", " \"T2M\",\n", " \"TQI\",\n", " \"TQL\",\n", " \"TQV\",\n", " \"TS\",\n", " \"U10M\",\n", " \"V10M\",\n", " \"Z0M\",\n", "]\n", "static_surface_vars = [\"FRACI\", \"FRLAND\", \"FROCEAN\", \"PHIS\"]\n", "vertical_vars = [\"CLOUD\", \"H\", \"OMEGA\", \"PL\", \"QI\", \"QL\", \"QV\", \"T\", \"U\", \"V\"]\n", "levels = [\n", " 34.0,\n", " 39.0,\n", " 41.0,\n", " 43.0,\n", " 44.0,\n", " 45.0,\n", " 48.0,\n", " 51.0,\n", " 53.0,\n", " 56.0,\n", " 63.0,\n", " 68.0,\n", " 71.0,\n", " 
72.0,\n", "]\n", "padding = {\"level\": [0, 0], \"lat\": [0, -1], \"lon\": [0, 0]}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Lead time\n", "When performing auto-regressive rollout, the intermediate steps require the\n", "static data at those times and---if using `residual=climate`---the intermediate\n", "climatology. We provide a dataloader that extends the MERRA2 loader of the\n", "core model, adding in these additional terms. Further, it returns target data for\n", "the intermediate steps if those are required for loss terms. \n", "\n", "The `lead_time` flag still sets the target time for the model; however, it is now\n", "only a single value and must be a positive integer multiple of `-input_time`. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "lead_time = 3 # This variable can be changed to change the task\n", "input_time = -3 # This variable can be changed to change the task" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data file\n", "MERRA-2 data is available from 1980 to the present day,\n", "at 3-hour temporal resolution. The dataloader we have provided\n", "expects the surface data and vertical data to be saved in\n", "separate files, and when provided with the directories, will\n", "search for the relevant data that falls within the provided time range.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "159bec6eee1846d680fe284324094487", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 1 files: 0%| | 0/1 [00:00 dict[str, Tensor]:\n", " \"\"\"Preprocessing function for MERRA2 Dataset\n", "\n", " Args:\n", " batch (dict): List of training samples, each sample should be a\n", " dictionary with the following keys::\n", "\n", " 'sur_static': Numpy array of shape (3, lat, lon). 
For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).\n", " 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).\n", " 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).\n", " 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).\n", " 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).\n", " 'sur_climate': Torch tensor of shape (parameter, lat, lon)\n", " 'ulv_climate': Torch tensor of shape (parameter, level, lat, lon)\n", " 'lead_time': Integer.\n", " 'input_time': Integer.\n", "\n", " padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.\n", "\n", " Returns:\n", " Dictionary with the following keys::\n", "\n", " 'x': [batch, time, parameter, lat, lon]\n", " 'y': [batch, parameter, lat, lon]\n", " 'static': [batch, parameter, lat, lon]\n", " 'lead_time': [batch]\n", " 'input_time': [batch]\n", " 'climate (Optional)': [batch, parameter, lat, lon]\n", "\n", " Note:\n", " Here, for x and y, 'parameter' is [surface parameter, upper level,\n", " parameter x level]. 
Similarly for the static information we have\n", " [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),\n", " ...].\n", " \"\"\" # noqa: E501\n", " b0 = batch[0]\n", " nbatch = len(batch)\n", " data_keys = set(b0.keys())\n", "\n", " essential_keys = {\n", " \"sur_static\",\n", " \"sur_vals\",\n", " \"sur_tars\",\n", " \"ulv_vals\",\n", " \"ulv_tars\",\n", " \"input_time\",\n", " \"lead_time\",\n", " }\n", "\n", " climate_keys = {\n", " \"sur_climate\",\n", " \"ulv_climate\",\n", " }\n", "\n", " all_keys = essential_keys | climate_keys\n", "\n", " if not essential_keys.issubset(data_keys):\n", " raise ValueError(\"Missing essential keys.\")\n", "\n", " if not data_keys.issubset(all_keys):\n", " raise ValueError(\"Unexpected keys in batch.\")\n", "\n", " # Bring all tensors from the batch into a single tensor\n", " upl_x = torch.empty((nbatch, *b0[\"ulv_vals\"].shape))\n", " upl_y = torch.empty((nbatch, *b0[\"ulv_tars\"].shape))\n", "\n", " sur_x = torch.empty((nbatch, *b0[\"sur_vals\"].shape))\n", " sur_y = torch.empty((nbatch, *b0[\"sur_tars\"].shape))\n", "\n", " sur_sta = torch.empty((nbatch, *b0[\"sur_static\"].shape))\n", "\n", " lead_time = torch.empty((nbatch,), dtype=torch.float32)\n", " input_time = torch.empty((nbatch,), dtype=torch.float32)\n", "\n", " for i, rec in enumerate(batch):\n", " sur_x[i] = rec[\"sur_vals\"]\n", " sur_y[i] = rec[\"sur_tars\"]\n", "\n", " upl_x[i] = rec[\"ulv_vals\"]\n", " upl_y[i] = rec[\"ulv_tars\"]\n", "\n", " sur_sta[i] = rec[\"sur_static\"]\n", "\n", " lead_time[i] = rec[\"lead_time\"]\n", " input_time[i] = rec[\"input_time\"]\n", "\n", " return_value = {\n", " \"lead_time\": lead_time,\n", " \"input_time\": input_time,\n", " }\n", "\n", " # Reshape (batch, parameter, level, time, lat, lon) ->\n", " # (batch, time, parameter, level, lat, lon)\n", " upl_x = upl_x.permute((0, 3, 1, 2, 4, 5))\n", " upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))\n", " # Reshape (batch, parameter, time, lat, lon) ->\n", " # (batch, 
time, parameter, lat, lon)\n", " sur_x = sur_x.permute((0, 2, 1, 3, 4))\n", " sur_y = sur_y.permute((0, 2, 1, 3, 4))\n", "\n", " # Pad\n", " padding_2d = (*padding[\"lon\"], *padding[\"lat\"])\n", "\n", " def pad2d(x):\n", " return torch.nn.functional.pad(x, padding_2d, mode=\"constant\", value=0)\n", "\n", " padding_3d = (*padding[\"lon\"], *padding[\"lat\"], *padding[\"level\"])\n", "\n", " def pad3d(x):\n", " return torch.nn.functional.pad(x, padding_3d, mode=\"constant\", value=0)\n", "\n", " sur_x = pad2d(sur_x).contiguous()\n", " upl_x = pad3d(upl_x).contiguous()\n", " sur_y = pad2d(sur_y).contiguous()\n", " upl_y = pad3d(upl_y).contiguous()\n", " return_value[\"static\"] = pad2d(sur_sta).contiguous()\n", "\n", " # Remove time for targets\n", " upl_y = torch.squeeze(upl_y, 1)\n", " sur_y = torch.squeeze(sur_y, 1)\n", "\n", " # We stack along the combined parameter x level dimension\n", " return_value[\"x\"] = torch.cat(\n", " (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2\n", " )\n", " return_value[\"y\"] = torch.cat(\n", " (sur_y, upl_y.view(upl_y.shape[0], -1, *upl_y.shape[3:])), dim=1\n", " )\n", "\n", " if climate_keys.issubset(data_keys):\n", " sur_climate = torch.empty((nbatch, *b0[\"sur_climate\"].shape))\n", " ulv_climate = torch.empty((nbatch, *b0[\"ulv_climate\"].shape))\n", " for i, rec in enumerate(batch):\n", " sur_climate[i] = rec[\"sur_climate\"]\n", " ulv_climate[i] = rec[\"ulv_climate\"]\n", " sur_climate = pad2d(sur_climate)\n", " ulv_climate = pad3d(ulv_climate)\n", "\n", " return_value[\"climate\"] = torch.cat(\n", " (\n", " sur_climate,\n", " ulv_climate.view(nbatch, -1, *ulv_climate.shape[3:]),\n", " ),\n", " dim=1,\n", " )\n", "\n", " return return_value\n", "\n", "\n", "def input_scalers(\n", " surf_vars: list[str],\n", " vert_vars: list[str],\n", " levels: list[float],\n", " surf_path: str | Path,\n", " vert_path: str | Path,\n", ") -> tuple[Tensor, Tensor]:\n", " \"\"\"Reads the input scalers\n", "\n", " Args:\n", 
" surf_vars: surface variables to be used.\n", " vert_vars: vertical variables to be used.\n", " levels: MERRA2 levels to use.\n", " surf_path: path to surface scalers file.\n", " vert_path: path to vertical level scalers file.\n", "\n", " Returns:\n", " mu (Tensor): mean values\n", " var (Tensor): varience values\n", " \"\"\"\n", " with h5py.File(Path(surf_path), \"r\", libver=\"latest\") as surf_file:\n", " stats = [x.decode().lower() for x in surf_file[\"statistic\"][()]]\n", " mu_idx = stats.index(\"mu\")\n", " sig_idx = stats.index(\"sigma\")\n", "\n", " s_mu = torch.tensor([surf_file[k][()][mu_idx] for k in surf_vars])\n", " s_sig = torch.tensor([surf_file[k][()][sig_idx] for k in surf_vars])\n", "\n", " with h5py.File(Path(vert_path), \"r\", libver=\"latest\") as vert_file:\n", " stats = [x.decode().lower() for x in vert_file[\"statistic\"][()]]\n", " mu_idx = stats.index(\"mu\")\n", " sig_idx = stats.index(\"sigma\")\n", "\n", " lvl = vert_file[\"lev\"][()]\n", " l_idx = [np.where(lvl == v)[0].item() for v in levels]\n", "\n", " v_mu = np.array([vert_file[k][()][mu_idx, l_idx] for k in vert_vars])\n", " v_sig = np.array([vert_file[k][()][sig_idx, l_idx] for k in vert_vars])\n", "\n", " v_mu = torch.from_numpy(v_mu).view(-1)\n", " v_sig = torch.from_numpy(v_sig).view(-1)\n", "\n", " mu = torch.cat((s_mu, v_mu), dim=0).to(torch.float32)\n", " sig = torch.cat((s_sig, v_sig), dim=0).to(torch.float32).clamp(1e-4, 1e4)\n", " return mu, sig\n", "\n", "\n", "def static_input_scalers(\n", " scalar_path: str | Path, stat_vars: list[str], unscaled_params: int = 7\n", ") -> tuple[Tensor, Tensor]:\n", " scalar_path = Path(scalar_path)\n", "\n", " with h5py.File(scalar_path, \"r\", libver=\"latest\") as scaler_file:\n", " stats = [x.decode().lower() for x in scaler_file[\"statistic\"][()]]\n", " mu_idx = stats.index(\"mu\")\n", " sig_idx = stats.index(\"sigma\")\n", "\n", " mu = torch.tensor([scaler_file[k][()][mu_idx] for k in stat_vars])\n", " sig = 
torch.tensor([scaler_file[k][()][sig_idx] for k in stat_vars])\n", "\n", " z = torch.zeros(unscaled_params, dtype=mu.dtype, device=mu.device)\n", " o = torch.ones(unscaled_params, dtype=sig.dtype, device=sig.device)\n", " mu = torch.cat((z, mu), dim=0).to(torch.float32)\n", " sig = torch.cat((o, sig), dim=0).to(torch.float32)\n", "\n", " return mu, sig.clamp(1e-4, 1e4)\n", "\n", "\n", "def output_scalers(\n", " surf_vars: list[str],\n", " vert_vars: list[str],\n", " levels: list[float],\n", " surf_path: str | Path,\n", " vert_path: str | Path,\n", ") -> Tensor:\n", " surf_path = Path(surf_path)\n", " vert_path = Path(vert_path)\n", "\n", " with h5py.File(surf_path, \"r\", libver=\"latest\") as surf_file:\n", " svars = torch.tensor([surf_file[k][()] for k in surf_vars])\n", "\n", " with h5py.File(vert_path, \"r\", libver=\"latest\") as vert_file:\n", " lvl = vert_file[\"lev\"][()]\n", " l_idx = [np.where(lvl == v)[0].item() for v in levels]\n", " vvars = np.array([vert_file[k][()][l_idx] for k in vert_vars])\n", " vvars = torch.from_numpy(vvars).view(-1)\n", "\n", " var = torch.cat((svars, vvars), dim=0).to(torch.float32).clamp(1e-7, 1e7)\n", "\n", " return var\n", "\n", "\n", "class SampleSpec:\n", " \"\"\"\n", " A data class to collect the information used to define a sample.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " inputs: tuple[pd.Timestamp, pd.Timestamp],\n", " lead_time: int,\n", " target: pd.Timestamp | list[pd.Timestamp],\n", " ):\n", " \"\"\"\n", " Args:\n", " inputs: Tuple of timestamps. In ascending order.\n", " lead_time: Lead time. In hours.\n", " target: Timestamp of the target. 
Can be before or after the inputs.\n", " \"\"\"\n", " if not inputs[0] < inputs[1]:\n", " raise ValueError(\n", " \"Timestamps in `inputs` should be in strictly ascending order.\"\n", " )\n", "\n", " self.inputs = inputs\n", " self.input_time = (inputs[1] - inputs[0]).total_seconds() / 3600\n", " self.lead_time = lead_time\n", " self.target = target\n", "\n", " self.times = [*inputs, target]\n", " self.stat_times = [inputs[-1]]\n", "\n", " @property\n", " def climatology_info(self) -> tuple[int, int]:\n", " \"\"\"Get the required climatology info.\n", "\n", " :return: information required to obtain climatology data. Essentially\n", " this is the day of the year and hour of the day of the target\n", " timestamp, with the former restricted to the interval [1, 365].\n", " :rtype: tuple\n", " \"\"\"\n", " return (min(self.target.dayofyear, 365), self.target.hour)\n", "\n", " @property\n", " def year(self) -> int:\n", " return self.inputs[1].year\n", "\n", " @property\n", " def dayofyear(self) -> int:\n", " return self.inputs[1].dayofyear\n", "\n", " @property\n", " def hourofday(self) -> int:\n", " return self.inputs[1].hour\n", "\n", " def _info_str(self) -> str:\n", " iso_8601 = \"%Y-%m-%dT%H:%M:%S\"\n", "\n", " return (\n", " f\"Issue time: {self.inputs[1].strftime(iso_8601)}\\n\"\n", " f\"Lead time: {self.lead_time} hours ahead\\n\"\n", " f\"Input delta: {self.input_time} hours\\n\"\n", " f\"Target time: {self.target.strftime(iso_8601)}\"\n", " )\n", "\n", " @classmethod\n", " def get(cls, timestamp: pd.Timestamp, dt: int, lead_time: int):\n", " \"\"\"Given a timestamp and lead time, generates a SampleSpec object\n", " describing the sample further.\n", "\n", " Args:\n", " timestamp: Timestamp of the sample, i.e. the later of the two\n", " input timestamps.\n", " dt: Time between input samples, in hours.\n", " lead_time: Lead time. 
In hours.\n", "\n", " Returns:\n", " SampleSpec\n", " \"\"\" # noqa: E501\n", " assert dt > 0, \"dt should be positive\"\n", " lt = pd.to_timedelta(lead_time, unit=\"h\")\n", " dt = pd.to_timedelta(dt, unit=\"h\")\n", "\n", " if lead_time >= 0:\n", " timestamp_target = timestamp + lt\n", " else:\n", " timestamp_target = timestamp - dt + lt\n", "\n", " spec = cls(\n", " inputs=(timestamp - dt, timestamp),\n", " lead_time=lead_time,\n", " target=timestamp_target,\n", " )\n", "\n", " return spec\n", "\n", " def __repr__(self) -> str:\n", " return self._info_str()\n", "\n", " def __str__(self) -> str:\n", " return self._info_str()\n", "\n", "\n", "class Merra2Dataset(Dataset):\n", " \"\"\"MERRA2 dataset. The dataset unifies surface and vertical data as well as\n", " optional climatology.\n", "\n", " Samples come in the form of a dictionary. Not all keys support all\n", " variables, yet the general ordering of dimensions is\n", " parameter, level, time, lat, lon\n", "\n", " Note:\n", " Data is assumed to be in NetCDF files containing daily data at 3-hourly\n", " intervals. These follow the naming patterns\n", " MERRA2_sfc_YYYYMMDD.nc and MERRA_pres_YYYYMMDD.nc and can be located in\n", " two different locations. Optional climatology data comes from files\n", " climate_surface_doyDOY_hourHOD.nc and\n", " climate_vertical_doyDOY_hourHOD.nc.\n", "\n", "\n", " Note:\n", " `_get_valid_timestamps` assembles a set of all timestamps for which\n", " there is data (with hourly resolutions). The result is stored in\n", " `_valid_timestamps`. `_get_valid_climate_timestamps` does the same with\n", " climatology data and stores it in `_valid_climate_timestamps`.\n", "\n", " Based on this information, `samples` generates a list of valid samples,\n", " stored in `samples`. 
Here the format is::\n", "\n", " [\n", " [\n", " (timestamp 1, lead time A),\n", " (timestamp 1, lead time B),\n", " (timestamp 1, lead time C),\n", " ],\n", " [\n", " (timestamp 2, lead time D),\n", " (timestamp 2, lead time E),\n", " ]\n", " ]\n", "\n", " That is, the outer list iterates over timestamps (init times), the\n", " inner over lead times. Only valid entries are stored.\n", " \"\"\"\n", "\n", " valid_vertical_vars = [\n", " \"CLOUD\",\n", " \"H\",\n", " \"OMEGA\",\n", " \"PL\",\n", " \"QI\",\n", " \"QL\",\n", " \"QV\",\n", " \"T\",\n", " \"U\",\n", " \"V\",\n", " ]\n", " valid_surface_vars = [\n", " \"EFLUX\",\n", " \"GWETROOT\",\n", " \"HFLUX\",\n", " \"LAI\",\n", " \"LWGAB\",\n", " \"LWGEM\",\n", " \"LWTUP\",\n", " \"PRECTOT\",\n", " \"PS\",\n", " \"QV2M\",\n", " \"SLP\",\n", " \"SWGNT\",\n", " \"SWTNT\",\n", " \"T2M\",\n", " \"TQI\",\n", " \"TQL\",\n", " \"TQV\",\n", " \"TS\",\n", " \"U10M\",\n", " \"V10M\",\n", " \"Z0M\",\n", " ]\n", " valid_static_surface_vars = [\"FRACI\", \"FRLAND\", \"FROCEAN\", \"PHIS\"]\n", "\n", " valid_levels = [\n", " 34.0,\n", " 39.0,\n", " 41.0,\n", " 43.0,\n", " 44.0,\n", " 45.0,\n", " 48.0,\n", " 51.0,\n", " 53.0,\n", " 56.0,\n", " 63.0,\n", " 68.0,\n", " 71.0,\n", " 72.0,\n", " ]\n", "\n", " timedelta_input = pd.to_timedelta(3, unit=\"h\")\n", "\n", " def __init__(\n", " self,\n", " time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],\n", " lead_times: list[int],\n", " input_times: list[int],\n", " data_path_surface: str | Path,\n", " data_path_vertical: str | Path,\n", " climatology_path_surface: str | Path | None = None,\n", " climatology_path_vertical: str | Path | None = None,\n", " surface_vars: list[str] | None = None,\n", " static_surface_vars: list[str] | None = None,\n", " vertical_vars: list[str] | None = None,\n", " levels: list[float] | None = None,\n", " roll_longitudes: int = 0,\n", " positional_encoding: str = \"absolute\",\n", " rtype: type = np.float32,\n", " dtype: torch.dtype = 
torch.float32,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " data_path_surface: Location of surface data.\n", " data_path_vertical: Location of vertical data.\n", " climatology_path_surface: Location of (optional) surface\n", " climatology.\n", " climatology_path_vertical: Location of (optional) vertical\n", " climatology.\n", " surface_vars: Surface variables.\n", " static_surface_vars: Static surface variables.\n", " vertical_vars: Vertical variables.\n", " levels: Levels.\n", " time_range: Used to subset data.\n", " lead_times: Lead times for generalized forecasting.\n", " roll_longitudes: Set to non-zero value to roll data by a random\n", " amount along the longitude dimension.\n", " positional_encoding: possible values are\n", " ['absolute' (default), 'fourier'].\n", " 'absolute' returns lat lon encoded in 3 dimensions using sine\n", " and cosine\n", " 'fourier' returns lat/lon to be encoded by model\n", " rtype: numpy data type used during read\n", " dtype: torch data type of data output\n", " \"\"\"\n", "\n", " self.time_range = (\n", " pd.to_datetime(time_range[0]),\n", " pd.to_datetime(time_range[1]),\n", " )\n", " self.lead_times = lead_times\n", " self.input_times = input_times\n", " self._roll_longitudes = list(range(roll_longitudes + 1))\n", "\n", " self._uvars = vertical_vars or self.valid_vertical_vars\n", " self._level = levels or self.valid_levels\n", " self._svars = surface_vars or self.valid_surface_vars\n", " self._sstat = static_surface_vars or self.valid_static_surface_vars\n", " self._nuvars = len(self._uvars)\n", " self._nlevel = len(self._level)\n", " self._nsvars = len(self._svars)\n", " self._nsstat = len(self._sstat)\n", "\n", " self.rtype = rtype\n", " self.dtype = dtype\n", "\n", " self.positional_encoding = positional_encoding\n", "\n", " self._data_path_surface = Path(data_path_surface)\n", " self._data_path_vertical = Path(data_path_vertical)\n", "\n", " self.dir_exists(self._data_path_surface)\n", 
" self.dir_exists(self._data_path_vertical)\n", "\n", " self._get_coordinates()\n", "\n", " self._climatology_path_surface = Path(climatology_path_surface) or None\n", " self._climatology_path_vertical = (\n", " Path(climatology_path_vertical) or None\n", " )\n", " self._require_clim = (\n", " self._climatology_path_surface is not None\n", " and self._climatology_path_vertical is not None\n", " )\n", "\n", " if self._require_clim:\n", " self.dir_exists(self._climatology_path_surface)\n", " self.dir_exists(self._climatology_path_vertical)\n", " elif (\n", " climatology_path_surface is None\n", " and climatology_path_vertical is None\n", " ):\n", " self._climatology_path_surface = None\n", " self._climatology_path_vertical = None\n", " else:\n", " raise ValueError(\n", " \"Either both or neither of\"\n", " \"`climatology_path_surface` and\"\n", " \"`climatology_path_vertical` should be None.\"\n", " )\n", "\n", " if not set(self._svars).issubset(set(self.valid_surface_vars)):\n", " raise ValueError(\"Invalid surface variable.\")\n", "\n", " if not set(self._sstat).issubset(set(self.valid_static_surface_vars)):\n", " raise ValueError(\"Invalid static surface variable.\")\n", "\n", " if not set(self._uvars).issubset(set(self.valid_vertical_vars)):\n", " raise ValueError(\"Inalid vertical variable.\")\n", "\n", " if not set(self._level).issubset(set(self.valid_levels)):\n", " raise ValueError(\"Invalid level.\")\n", "\n", " @staticmethod\n", " def dir_exists(path: Path) -> None:\n", " if not path.is_dir():\n", " raise ValueError(f\"Directory {path} does not exist.\")\n", "\n", " @property\n", " def upper_shape(self) -> tuple:\n", " \"\"\"Returns the vertical variables shape\n", " Returns:\n", " tuple: vertical variable shape in the following order::\n", "\n", " [VAR, LEV, TIME, LAT, LON]\n", " \"\"\"\n", " return self._nuvars, self._nlevel, 2, 361, 576\n", "\n", " @property\n", " def surface_shape(self) -> tuple:\n", " \"\"\"Returns the surface variables shape\n", "\n", 
" Returns:\n", " tuple: surafce shape in the following order::\n", "\n", " [VAR, LEV, TIME, LAT, LON]\n", " \"\"\"\n", " return self._nsvars, 2, 361, 576\n", "\n", " def data_file_surface(self, timestamp: pd.Timestamp) -> Path:\n", " \"\"\"Build the surfcae data file name based on timestamp\n", "\n", " Args:\n", " timestamp: a timestamp\n", "\n", " Returns:\n", " Path: constructed path\n", " \"\"\"\n", " pattern = \"MERRA2_sfc_%Y%m%d.nc\"\n", " data_file = self._data_path_surface / timestamp.strftime(pattern)\n", " return data_file\n", "\n", " def data_file_vertical(self, timestamp: pd.Timestamp) -> Path:\n", " \"\"\"Build the vertical data file name based on timestamp\n", "\n", " Args:\n", " timestamp: a timestamp\n", "\n", " Returns:\n", " Path: constructed path\n", " \"\"\"\n", " pattern = \"MERRA_pres_%Y%m%d.nc\"\n", " data_file = self._data_path_vertical / timestamp.strftime(pattern)\n", " return data_file\n", "\n", " def data_file_surface_climate(\n", " self,\n", " timestamp: pd.Timestamp | None = None,\n", " dayofyear: int | None = None,\n", " hourofday: int | None = None,\n", " ) -> Path:\n", " \"\"\"\n", " Returns the path to a climatology file based either on a timestamp or\n", " the dayofyear / hourofday combination.\n", " Args:\n", " timestamp: A timestamp.\n", " dayofyear: Day of the year. 1 to 366.\n", " hourofday: Hour of the day. 
0 to 23.\n", " Returns:\n", " Path: Path to climatology file.\n", " \"\"\"\n", " if timestamp is not None and (\n", " (dayofyear is not None) or (hourofday is not None)\n", " ):\n", " raise ValueError(\n", " \"Provide either timestamp or both dayofyear and hourofday.\"\n", " )\n", "\n", " if timestamp is not None:\n", " dayofyear = min(timestamp.dayofyear, 365)\n", " hourofday = timestamp.hour\n", "\n", " file_name = f\"climate_surface_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", " data_file = self._climatology_path_surface / file_name\n", " return data_file\n", "\n", " def data_file_vertical_climate(\n", " self,\n", " timestamp: pd.Timestamp | None = None,\n", " dayofyear: int | None = None,\n", " hourofday: int | None = None,\n", " ) -> Path:\n", " \"\"\"Returns the path to a climatology file based either on a timestamp\n", " or the dayofyear / hourofday combination.\n", "\n", " Args:\n", " timestamp: A timestamp. dayofyear: Day of the year. 1 to 366.\n", " hourofday: Hour of the day. 
0 to 23.\n", " Returns:\n", " Path: Path to climatology file.\n", " \"\"\"\n", " if timestamp is not None and (\n", " (dayofyear is not None) or (hourofday is not None)\n", " ):\n", " raise ValueError(\n", " \"Provide either timestamp or both dayofyear and hourofday.\"\n", " )\n", "\n", " if timestamp is not None:\n", " dayofyear = min(timestamp.dayofyear, 365)\n", " hourofday = timestamp.hour\n", "\n", " file_name = f\"climate_vertical_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", " data_file = self._climatology_path_vertical / file_name\n", " return data_file\n", "\n", " def _get_coordinates(self) -> None:\n", " \"\"\"\n", " Obtains the coordinates (latitudes and longitudes) from a single data\n", " file.\n", " \"\"\"\n", " timestamp = next(iter(self.valid_timestamps))\n", "\n", " file = self.data_file_surface(timestamp)\n", " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", " self.lats = lats = handle[\"lat\"][()].astype(self.rtype)\n", " self.lons = lons = handle[\"lon\"][()].astype(self.rtype)\n", "\n", " deg_to_rad = np.pi / 180\n", " self._embed_lat = np.sin(lats * deg_to_rad).reshape(-1, 1)\n", "\n", " self._embed_lon = np.empty((2, 1, len(lons)), dtype=self.rtype)\n", " self._embed_lon[0, 0] = np.cos(lons * deg_to_rad)\n", " self._embed_lon[1, 0] = np.sin(lons * deg_to_rad)\n", "\n", " @ft.cached_property\n", " def lats(self) -> np.ndarray:\n", " timestamp = next(iter(self.valid_timestamps))\n", "\n", " file = self.data_file_surface(timestamp)\n", " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", " return handle[\"lat\"][()].astype(self.rtype)\n", "\n", " @ft.cached_property\n", " def lons(self) -> np.ndarray:\n", " timestamp = next(iter(self.valid_timestamps))\n", "\n", " file = self.data_file_surface(timestamp)\n", " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", " return handle[\"lon\"][()].astype(self.rtype)\n", "\n", " @ft.cached_property\n", " def position_signal(self) -> np.ndarray:\n", " \"\"\"Generates the 
\"position signal\" that is part of the static\n", " features.\n", "\n", " Returns:\n", " Tensor: Torch tensor of dimension (parameter, lat, lon) containing\n", " sin(lat), cos(lon), sin(lon).\n", " \"\"\"\n", "\n", " latitudes, longitudes = np.meshgrid(\n", " self.lats, self.lons, indexing=\"ij\"\n", " )\n", "\n", " if self.positional_encoding == \"absolute\":\n", " latitudes = latitudes / 360 * 2.0 * np.pi\n", " longitudes = longitudes / 360 * 2.0 * np.pi\n", " sur_static = np.stack(\n", " [np.sin(latitudes), np.cos(longitudes), np.sin(longitudes)],\n", " axis=0,\n", " )\n", " else:\n", " sur_static = np.stack([latitudes, longitudes], axis=0)\n", "\n", " sur_static = sur_static.astype(self.rtype)\n", "\n", " return sur_static\n", "\n", " @ft.cached_property\n", " def valid_timestamps(self) -> set[pd.Timestamp]:\n", " \"\"\"Generates list of valid timestamps based on available files. Only\n", " timestamps for which both surface and vertical information is available\n", " are considered valid.\n", " Returns:\n", " list: list of timestamps\n", " \"\"\"\n", "\n", " s_glob = self._data_path_surface.glob(\"MERRA2_sfc_????????.nc\")\n", " s_files = [os.path.basename(f) for f in s_glob]\n", " v_glob = self._data_path_surface.glob(\"MERRA_pres_????????.nc\")\n", " v_files = [os.path.basename(f) for f in v_glob]\n", "\n", " s_re = re.compile(r\"MERRA2_sfc_(\\d{8}).nc\\Z\")\n", " v_re = re.compile(r\"MERRA_pres_(\\d{8}).nc\\Z\")\n", " fmt = \"%Y%m%d\"\n", "\n", " s_times = {\n", " (datetime.strptime(m[1], fmt))\n", " for f in s_files\n", " if (m := s_re.match(f))\n", " }\n", " v_times = {\n", " (datetime.strptime(m[1], fmt))\n", " for f in v_files\n", " if (m := v_re.match(f))\n", " }\n", "\n", " times = s_times.intersection(v_times)\n", "\n", " # Each file contains a day at 3 hour intervals\n", " times = {\n", " t + timedelta(hours=i) for i in range(0, 24, 3) for t in times\n", " }\n", "\n", " start_time, end_time = self.time_range\n", " times = {pd.Timestamp(t) for t in 
times if start_time <= t <= end_time}\n", "\n", " return times\n", "\n", " @ft.cached_property\n", " def valid_climate_timestamps(self) -> set[tuple[int, int]]:\n", " \"\"\"Generates list of \"timestamps\" (dayofyear, hourofday) for which\n", " climatology data is present. Only instances for which surface and\n", " vertical data is available are considered valid.\n", " Returns:\n", " list: List of tuples describing valid climatology instances.\n", " \"\"\"\n", " if not self._require_clim:\n", " return set()\n", "\n", " s_glob = self._climatology_path_surface.glob(\n", " \"climate_surface_doy???_hour??.nc\"\n", " )\n", " s_files = [os.path.basename(f) for f in s_glob]\n", "\n", " v_glob = self._climatology_path_vertical.glob(\n", " \"climate_vertical_doy???_hour??.nc\"\n", " )\n", " v_files = [os.path.basename(f) for f in v_glob]\n", "\n", " s_re = re.compile(r\"climate_surface_doy(\\d{3})_hour(\\d{2}).nc\\Z\")\n", " v_re = re.compile(r\"climate_vertical_doy(\\d{3})_hour(\\d{2}).nc\\Z\")\n", "\n", " s_times = {\n", " (int(m[1]), int(m[2])) for f in s_files if (m := s_re.match(f))\n", " }\n", " v_times = {\n", " (int(m[1]), int(m[2])) for f in v_files if (m := v_re.match(f))\n", " }\n", "\n", " times = s_times.intersection(v_times)\n", "\n", " return times\n", "\n", " def _data_available(self, spec: SampleSpec) -> bool:\n", " \"\"\"\n", " Checks whether data is available for a given SampleSpec object. Does so\n", " using the internal sets with available data previously constructed. 
Not\n", " by checking the file system.\n", " Args:\n", " spec: SampleSpec object as returned by SampleSpec.get\n", " Returns:\n", " bool: whether data is available.\n", " \"\"\"\n", " valid = set(spec.times).issubset(self.valid_timestamps)\n", "\n", " if self._require_clim:\n", " sci = spec.climatology_info\n", " ci = set(sci) if isinstance(sci, list) else set([sci]) # noqa: C405\n", " valid &= ci.issubset(self.valid_climate_timestamps)\n", "\n", " return valid\n", "\n", " @ft.cached_property\n", " def samples(self) -> list[tuple[pd.Timestamp, int, int]]:\n", " \"\"\"\n", " Generates list of all valid samples.\n", " Returns:\n", " list: List of tuples (timestamp, input time, lead time).\n", " \"\"\"\n", " valid_samples = []\n", " dts = [(it, lt) for it in self.input_times for lt in self.lead_times]\n", "\n", " for timestamp in sorted(self.valid_timestamps):\n", " timestamp_samples = []\n", " for it, lt in dts:\n", " spec = SampleSpec.get(timestamp, -it, lt)\n", "\n", " if self._data_available(spec):\n", " timestamp_samples.append((timestamp, it, lt))\n", "\n", " if timestamp_samples:\n", " valid_samples.append(timestamp_samples)\n", "\n", " return valid_samples\n", "\n", " def _to_torch(\n", " self,\n", " data: dict[str, Tensor | list[Tensor]],\n", " dtype: torch.dtype = torch.float32,\n", " ) -> dict[str, Tensor | list[Tensor]]:\n", " out = {}\n", " for k, v in data.items():\n", " if isinstance(v, list):\n", " out[k] = [torch.from_numpy(x).to(dtype) for x in v]\n", " else:\n", " out[k] = torch.from_numpy(v).to(dtype)\n", "\n", " return out\n", "\n", " def _lat_roll(\n", " self, data: dict[str, Tensor | list[Tensor]], n: int\n", " ) -> dict[str, Tensor | list[Tensor]]:\n", " out = {}\n", " for k, v in data.items():\n", " if isinstance(v, list):\n", " out[k] = [torch.roll(x, shifts=n, dims=-1) for x in v]\n", " else:\n", " out[k] = torch.roll(v, shifts=n, dims=-1)\n", "\n", " return out\n", "\n", " def _read_static_data(\n", " self, file: str | Path, doy: int, hod: 
int\n", " ) -> np.ndarray:\n", " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", " lats_surf = handle[\"lat\"]\n", " lons_surf = handle[\"lon\"]\n", "\n", " nll = (len(lats_surf), len(lons_surf))\n", "\n", " npos = len(self.position_signal)\n", " ntime = 4\n", "\n", " nstat = npos + ntime + self._nsstat\n", " data = np.empty((nstat, *nll), dtype=self.rtype)\n", "\n", " for i, key in enumerate(self._sstat, start=npos + ntime):\n", " data[i] = handle[key][()].astype(dtype=self.rtype)\n", "\n", " # [position signal], cos(doy), sin(doy), cos(hod), sin(hod)\n", " data[0:npos] = self.position_signal\n", " data[npos + 0] = np.cos(2 * np.pi * doy / 366)\n", " data[npos + 1] = np.sin(2 * np.pi * doy / 366)\n", " data[npos + 2] = np.cos(2 * np.pi * hod / 24)\n", " data[npos + 3] = np.sin(2 * np.pi * hod / 24)\n", "\n", " return data\n", "\n", " def _read_surface(\n", " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", " ) -> np.ndarray:\n", " data = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", "\n", " for i, key in enumerate(self._svars):\n", " data[i] = handle[key][tidx][()].astype(dtype=self.rtype)\n", "\n", " return data\n", "\n", " def _read_levels(\n", " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", " ) -> np.ndarray:\n", " lvls = handle[\"lev\"][()]\n", " lidx = self._level_idxs(lvls)\n", "\n", " data = np.empty((self._nuvars, self._nlevel, *nll), dtype=self.rtype)\n", "\n", " for i, key in enumerate(self._uvars):\n", " data[i] = handle[key][tidx, lidx][()].astype(dtype=self.rtype)\n", "\n", " return np.ascontiguousarray(np.flip(data, axis=1))\n", "\n", " def _level_idxs(self, lvls):\n", " lidx = [np.argwhere(lvls == int(lvl)).item() for lvl in self._level]\n", " return sorted(lidx)\n", "\n", " @staticmethod\n", " def _date_to_tidx(date: datetime | pd.Timestamp, handle: h5py.File) -> int:\n", " if isinstance(date, pd.Timestamp):\n", " date = date.to_pydatetime()\n", "\n", " time = handle[\"time\"]\n", "\n", " t0 = 
time.attrs[\"begin_time\"][()].item()\n", " d0 = f\"{time.attrs['begin_date'][()].item()}\"\n", "\n", " offset = datetime.strptime(d0, \"%Y%m%d\")\n", "\n", " times = [offset + timedelta(minutes=int(t + t0)) for t in time[()]]\n", " return times.index(date)\n", "\n", " def _read_data(\n", " self, file_pair: tuple[str, str], date: datetime\n", " ) -> dict[str, np.ndarray]:\n", " s_file, v_file = file_pair\n", "\n", " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", " lats_surf = shandle[\"lat\"]\n", " lons_surf = shandle[\"lon\"]\n", "\n", " nll = (len(lats_surf), len(lons_surf))\n", "\n", " tidx = self._date_to_tidx(date, shandle)\n", "\n", " sdata = self._read_surface(tidx, nll, shandle)\n", "\n", " with h5py.File(v_file, \"r\", libver=\"latest\") as vhandle:\n", " lats_vert = vhandle[\"lat\"]\n", " lons_vert = vhandle[\"lon\"]\n", "\n", " nll = (len(lats_vert), len(lons_vert))\n", "\n", " tidx = self._date_to_tidx(date, vhandle)\n", "\n", " vdata = self._read_levels(tidx, nll, vhandle)\n", "\n", " data = {\"vert\": vdata, \"surf\": sdata}\n", "\n", " return data\n", "\n", " def _read_climate(\n", " self, file_pair: tuple[str, str]\n", " ) -> dict[str, np.ndarray]:\n", " s_file, v_file = file_pair\n", "\n", " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", " lats_surf = shandle[\"lat\"]\n", " lons_surf = shandle[\"lon\"]\n", "\n", " nll = (len(lats_surf), len(lons_surf))\n", "\n", " sdata = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", "\n", " for i, key in enumerate(self._svars):\n", " sdata[i] = shandle[key][()].astype(dtype=self.rtype)\n", "\n", " with h5py.File(v_file, \"r\", libver=\"latest\") as vhandle:\n", " lats_vert = vhandle[\"lat\"]\n", " lons_vert = vhandle[\"lon\"]\n", "\n", " nll = (len(lats_vert), len(lons_vert))\n", "\n", " lvls = vhandle[\"lev\"][()]\n", " lidx = self._level_idxs(lvls)\n", "\n", " vdata = np.empty(\n", " (self._nuvars, self._nlevel, *nll), dtype=self.rtype\n", " )\n", "\n", " for i, key in 
enumerate(self._uvars):\n", " vdata[i] = vhandle[key][lidx][()].astype(dtype=self.rtype)\n", "\n", " data = {\n", " \"vert\": np.ascontiguousarray(np.flip(vdata, axis=1)),\n", " \"surf\": sdata,\n", " }\n", "\n", " return data\n", "\n", " def get_data_from_sample_spec(\n", " self, spec: SampleSpec\n", " ) -> dict[str, Tensor | int | float]:\n", " \"\"\"Loads and assembles sample data given a SampleSpec object.\n", "\n", " Args:\n", " spec (SampleSpec): Full details regarding the data to be loaded\n", " Returns:\n", " dict: Dictionary with the following keys::\n", "\n", " 'sur_static': Torch tensor of shape [parameter, lat, lon]. For\n", " each pixel (lat, lon), the first 7 dimensions index sin(lat),\n", " cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).\n", " Where doy is the day of the year [1, 366] and hod the hour of\n", " the day [0, 23].\n", " 'sur_vals': Torch tensor of shape [parameter, time, lat, lon].\n", " 'sur_tars': Torch tensor of shape [parameter, time, lat, lon].\n", " 'ulv_vals': Torch tensor of shape [parameter, level, time, lat, lon].\n", " 'ulv_tars': Torch tensor of shape [parameter, level, time, lat, lon].\n", " 'sur_climate': Torch tensor of shape [parameter, lat, lon].\n", " 'ulv_climate': Torch tensor of shape [parameter, level, lat, lon].\n", " 'lead_time': Float.\n", " 'input_time': Float.\n", "\n", " \"\"\" # noqa: E501\n", "\n", " # We assemble the unique timestamps for which we need data.\n", " vals_required = {*spec.times}\n", " stat_required = {*spec.stat_times}\n", "\n", " # We assemble the unique data files from which we need value data\n", " vals_file_map = defaultdict(list)\n", " for t in vals_required:\n", " data_files = (\n", " self.data_file_surface(t),\n", " self.data_file_vertical(t),\n", " )\n", " vals_file_map[data_files].append(t)\n", "\n", " # We assemble the unique data files from which we need static data\n", " stat_file_map = defaultdict(list)\n", " for t in stat_required:\n", " data_files = (\n", " 
self.data_file_surface(t),\n", " self.data_file_vertical(t),\n", " )\n", " stat_file_map[data_files].append(t)\n", "\n", " # Load the value data\n", " data = {}\n", " for data_files, times in vals_file_map.items():\n", " for time in times:\n", " data[time] = self._read_data(data_files, time)\n", "\n", " # Combine times\n", " sample_data = {}\n", "\n", " input_upl = np.stack([data[t][\"vert\"] for t in spec.inputs], axis=2)\n", " sample_data[\"ulv_vals\"] = input_upl\n", "\n", " target_upl = data[spec.target][\"vert\"]\n", " sample_data[\"ulv_tars\"] = target_upl[:, :, None]\n", "\n", " input_sur = np.stack([data[t][\"surf\"] for t in spec.inputs], axis=1)\n", " sample_data[\"sur_vals\"] = input_sur\n", "\n", " target_sur = data[spec.target][\"surf\"]\n", " sample_data[\"sur_tars\"] = target_sur[:, None]\n", "\n", " # Load the static data\n", " data_files, times = stat_file_map.popitem()\n", " time = times[0].dayofyear, times[0].hour\n", " sample_data[\"sur_static\"] = self._read_static_data(\n", " data_files[0], *time\n", " )\n", "\n", " # If required, load the climatology data\n", " if self._require_clim:\n", " ci_year, ci_hour = spec.climatology_info\n", "\n", " surf_file = self.data_file_surface_climate(\n", " dayofyear=ci_year,\n", " hourofday=ci_hour,\n", " )\n", "\n", " vert_file = self.data_file_vertical_climate(\n", " dayofyear=ci_year,\n", " hourofday=ci_hour,\n", " )\n", "\n", " clim_data = self._read_climate((surf_file, vert_file))\n", "\n", " sample_data[\"sur_climate\"] = clim_data[\"surf\"]\n", " sample_data[\"ulv_climate\"] = clim_data[\"vert\"]\n", "\n", " # Move the data from numpy to torch\n", " sample_data = self._to_torch(sample_data, dtype=self.dtype)\n", "\n", " # Optionally roll\n", " if len(self._roll_longitudes) > 0:\n", " roll_by = random.choice(self._roll_longitudes)\n", " sample_data = self._lat_roll(sample_data, roll_by)\n", "\n", " # Now that we have rolled, we can add the lead and input times\n", " sample_data[\"lead_time\"] = spec.lead_time\n", " 
sample_data[\"input_time\"] = spec.input_time\n", "\n", " return sample_data\n", "\n", " def get_data(\n", " self, timestamp: pd.Timestamp, input_time: int, lead_time: int\n", " ) -> dict[str, Tensor | int]:\n", " \"\"\"\n", " Loads data based on timestamp and lead time.\n", " Args:\n", " timestamp: Timestamp.\n", " input_time: time between input samples.\n", " lead_time: lead time.\n", " Returns:\n", " Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", " 'lead_time', 'input_time'.\n", " \"\"\"\n", " spec = SampleSpec.get(timestamp, -input_time, lead_time)\n", " sample_data = self.get_data_from_sample_spec(spec)\n", " return sample_data\n", "\n", " def __getitem__(self, idx: int) -> dict[str, Tensor | int]:\n", " \"\"\"\n", " Loads data based on sample index and random choice of sample.\n", " Args:\n", " idx: Sample index.\n", " Returns:\n", " Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", " 'lead_time', 'input_time'.\n", " \"\"\"\n", " sample_set = self.samples[idx]\n", " timestamp, input_time, lead_time, *nsteps = random.choice(sample_set)\n", " sample_data = self.get_data(timestamp, input_time, lead_time)\n", " return sample_data\n", "\n", " def __len__(self):\n", " return len(self.samples)\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import functools as ft\n", "import random\n", "from collections import defaultdict\n", "from copy import deepcopy\n", "from pathlib import Path\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from torch import Tensor\n", "\n", "# from PrithviWxC.dataloaders.merra2 import Merra2Dataset, SampleSpec\n", "\n", "\n", "def preproc(\n", " batch: list[dict[str, int | float | Tensor]], padding: dict[str, list[int]]\n", ") -> dict[str, Tensor]:\n", " \"\"\"Preprocessing function for MERRA2 Dataset\n", "\n", " Args:\n", " batch (list): 
List of training samples, each sample should be a\n", " dictionary with the following keys::\n", "\n", " 'sur_static': Numpy array of shape (3, lat, lon). For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).\n", " 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).\n", " 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).\n", " 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).\n", " 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).\n", " 'sur_climate': Torch tensor of shape (nstep, parameter, lat, lon)\n", " 'ulv_climate': Torch tensor of shape (nstep, parameter, level, lat, lon)\n", " 'lead_time': Integer.\n", " 'input_time': Integer.\n", "\n", " padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.\n", "\n", " Returns:\n", " Dictionary with the following keys::\n", "\n", " 'x': [batch, time, parameter, lat, lon]\n", " 'ys': [batch, nsteps, parameter, lat, lon]\n", " 'statics': [batch, nstep, parameter, lat, lon]\n", " 'lead_time': [batch]\n", " 'input_time': [batch]\n", " 'climates (Optional)': [batch, nsteps, parameter, lat, lon]\n", "\n", " Note:\n", " Here, for x and ys, 'parameter' is [surface parameter, upper level\n", " parameter x level]. 
Similarly for the static information we have\n", " [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),\n", " ...].\n", " \"\"\" # noqa: E501\n", "\n", " b0 = batch[0]\n", " nbatch = len(batch)\n", " data_keys = set(b0.keys())\n", "\n", " essential_keys = {\n", " \"sur_static\",\n", " \"sur_vals\",\n", " \"sur_tars\",\n", " \"ulv_vals\",\n", " \"ulv_tars\",\n", " \"input_time\",\n", " \"lead_time\",\n", " }\n", "\n", " climate_keys = {\n", " \"sur_climate\",\n", " \"ulv_climate\",\n", " }\n", "\n", " all_keys = essential_keys | climate_keys\n", "\n", " if not essential_keys.issubset(data_keys):\n", " raise ValueError(\"Missing essential keys.\")\n", "\n", " if not data_keys.issubset(all_keys):\n", " raise ValueError(\"Unexpected keys in batch.\")\n", "\n", " # Bring all tensors from the batch into a single tensor\n", " upl_x = torch.empty((nbatch, *b0[\"ulv_vals\"].shape))\n", " upl_y = torch.empty((nbatch, *b0[\"ulv_tars\"].shape))\n", "\n", " sur_x = torch.empty((nbatch, *b0[\"sur_vals\"].shape))\n", " sur_y = torch.empty((nbatch, *b0[\"sur_tars\"].shape))\n", "\n", " sur_sta = torch.empty((nbatch, *b0[\"sur_static\"].shape))\n", "\n", " lead_time = torch.empty(\n", " (nbatch, *b0[\"lead_time\"].shape),\n", " dtype=torch.float32,\n", " )\n", " input_time = torch.empty((nbatch,), dtype=torch.float32)\n", "\n", " for i, rec in enumerate(batch):\n", " sur_x[i] = torch.Tensor(rec[\"sur_vals\"])\n", " sur_y[i] = torch.Tensor(rec[\"sur_tars\"])\n", "\n", " upl_x[i] = torch.Tensor(rec[\"ulv_vals\"])\n", " upl_y[i] = torch.Tensor(rec[\"ulv_tars\"])\n", "\n", " sur_sta[i] = torch.Tensor(rec[\"sur_static\"])\n", "\n", " lead_time[i] = rec[\"lead_time\"]\n", " input_time[i] = rec[\"input_time\"]\n", "\n", " return_value = {\n", " \"lead_time\": lead_time,\n", " \"input_time\": input_time,\n", " \"target_time\": torch.sum(lead_time).reshape(-1),\n", " }\n", "\n", " # Reshape (batch, parameter, level, time, lat, lon)\n", " # -> (batch, time, parameter, 
level, lat, lon)\n", " upl_x = upl_x.permute((0, 3, 1, 2, 4, 5))\n", " upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))\n", "\n", " # Reshape (batch, parameter, time, lat, lon)\n", " # -> (batch, time, parameter, lat, lon)\n", " sur_x = sur_x.permute((0, 2, 1, 3, 4))\n", " sur_y = sur_y.permute((0, 2, 1, 3, 4))\n", "\n", " # Pad\n", " padding_2d = (*padding[\"lon\"], *padding[\"lat\"])\n", "\n", " def pad2d(x):\n", " return torch.nn.functional.pad(x, padding_2d, mode=\"constant\", value=0)\n", "\n", " padding_3d = (*padding[\"lon\"], *padding[\"lat\"], *padding[\"level\"])\n", "\n", " def pad3d(x):\n", " return torch.nn.functional.pad(x, padding_3d, mode=\"constant\", value=0)\n", "\n", " sur_x = pad2d(sur_x).contiguous()\n", " upl_x = pad3d(upl_x).contiguous()\n", " sur_y = pad2d(sur_y).contiguous()\n", " upl_y = pad3d(upl_y).contiguous()\n", " return_value[\"statics\"] = pad2d(sur_sta).contiguous()\n", "\n", " # We stack along the combined parameter level dimension\n", " return_value[\"x\"] = torch.cat(\n", " (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2\n", " )\n", " return_value[\"ys\"] = torch.cat(\n", " (sur_y, upl_y.view(*upl_y.shape[:2], -1, *upl_y.shape[4:])), dim=2\n", " )\n", "\n", " if climate_keys.issubset(data_keys):\n", " sur_climate = torch.empty((nbatch, *b0[\"sur_climate\"].shape))\n", " ulv_climate = torch.empty((nbatch, *b0[\"ulv_climate\"].shape))\n", " for i, rec in enumerate(batch):\n", " sur_climate[i] = rec[\"sur_climate\"]\n", " ulv_climate[i] = rec[\"ulv_climate\"]\n", " sur_climate = pad2d(sur_climate)\n", " ulv_climate = pad3d(ulv_climate)\n", "\n", " ulv_climate = ulv_climate.view(\n", " *ulv_climate.shape[:2], -1, *ulv_climate.shape[4:]\n", " )\n", " return_value[\"climates\"] = torch.cat((sur_climate, ulv_climate), dim=2)\n", "\n", " return return_value\n", "\n", "\n", "class RolloutSpec(SampleSpec):\n", " \"\"\"\n", " A data class to collect the information used to define a rollout sample.\n", " \"\"\"\n", "\n", " def 
__init__(\n", " self,\n", " inputs: tuple[pd.Timestamp, pd.Timestamp],\n", " lead_time: int,\n", " target: pd.Timestamp,\n", " ):\n", " \"\"\"\n", " Args:\n", " inputs: Tuple of timestamps. In ascending order.\n", " lead_time: Lead time. In hours.\n", " target: Timestamp of the target. Can be before or after the inputs.\n", " \"\"\"\n", " super().__init__(inputs, lead_time, target)\n", "\n", " self.dt = dt = pd.Timedelta(lead_time, unit=\"h\")\n", " self.inters = list(pd.date_range(inputs[-1], target, freq=dt))\n", "\n", " self._ctimes = deepcopy(self.inters)\n", " self.stat_times = deepcopy(self.inters)\n", "\n", " self.stat_times.pop(-1)\n", " self._ctimes.pop(0)\n", " self.inters.pop(0)\n", " self.inters.pop(-1)\n", "\n", " self.times = [*inputs, *self.inters, target]\n", " self.targets = self.times[2:]\n", " self.nsteps = len(self.times) - 2\n", "\n", " @property\n", " def climatology_info(self) -> list[tuple[int, int]]:\n", " \"\"\"Returns information required to obtain climatology data.\n", " Returns:\n", " list: list of (dayofyear, hourofday) tuples, one per rollout step.\n", " \"\"\"\n", " return [(min(t.dayofyear, 365), t.hour) for t in self._ctimes]\n", "\n", " def _info_str(self) -> str:\n", " iso_8601 = \"%Y-%m-%dT%H:%M:%S\"\n", "\n", " inter_str = \"\\n\".join(t.strftime(iso_8601) for t in self.inters)\n", "\n", " return (\n", " f\"Issue time: {self.inputs[1].strftime(iso_8601)}\\n\"\n", " f\"Lead time: {self.lead_time} hours ahead\\n\"\n", " f\"Target time: {self.target.strftime(iso_8601)}\\n\"\n", " f\"Intermediate times: {inter_str}\"\n", " )\n", "\n", " @classmethod\n", " def get(cls, timestamp: pd.Timestamp, lead_time: int, nsteps: int):\n", " \"\"\"Given a timestamp and lead time, generates a RolloutSpec object\n", " describing the sample further.\n", "\n", " Args:\n", " timestamp: Timestamp (issue time) of the sample.\n", " lead_time: Lead time. 
In hours.\n", " nsteps: Number of rollout steps.\n", "\n", " Returns:\n", " RolloutSpec object.\n", " \"\"\"\n", " if lead_time > 0:\n", " dt = pd.to_timedelta(lead_time, unit=\"h\")\n", " timestamp_target = timestamp + nsteps * dt\n", " else:\n", " raise ValueError(\"Rollout is only forwards\")\n", "\n", " spec = cls(\n", " inputs=(timestamp - dt, timestamp),\n", " lead_time=lead_time,\n", " target=timestamp_target,\n", " )\n", "\n", " return spec\n", "\n", " def __repr__(self) -> str:\n", " return self._info_str()\n", "\n", " def __str__(self) -> str:\n", " return self._info_str()\n", "\n", "\n", "class Merra2RolloutDataset(Merra2Dataset):\n", " \"\"\"Dataset class that reads MERRA2 data for performing rollout.\n", "\n", " Implementation details::\n", "\n", " Samples stores the list of valid samples. This takes the form\n", " ```\n", " [\n", " [(timestamp 1, -input_time, n_steps)],\n", " [(timestamp 2, -input_time, n_steps)],\n", " ]\n", " ```\n", " The nested list is for compatibility reasons with Merra2Dataset. Note\n", " that input time and n_steps are always the same value. 
For some reason\n", " the sign of input_time is the opposite to that in Merra2Dataset\n", " \"\"\"\n", "\n", " input_time_len = 2\n", "\n", " def __init__(\n", " self,\n", " time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],\n", " input_time: int | float | pd.Timedelta,\n", " lead_time: int | float,\n", " data_path_surface: str | Path,\n", " data_path_vertical: str | Path,\n", " climatology_path_surface: str | Path | None,\n", " climatology_path_vertical: str | Path | None,\n", " surface_vars: list[str],\n", " static_surface_vars: list[str],\n", " vertical_vars: list[str],\n", " levels: list[float],\n", " roll_longitudes: int = 0,\n", " positional_encoding: str = \"absolute\",\n", " ):\n", " \"\"\"\n", " Args:\n", " time_range: time range to consider when building dataset\n", " input_time: requested time between inputs\n", " lead_time: requested time to predict\n", " data_path_surface: path of surface data directory\n", " data_path_vertical: path of vertical data directory\n", " climatology_path_surface: path of surface climatology data\n", " directory\n", " climatology_path_vertical: path of vertical climatology data\n", " directory\n", " surface_vars: surface variables to return\n", " static_surface_vars: static surface variables to return\n", " vertical_vars: vertical variables to return\n", " levels: MERRA2 vertical levels to consider\n", " roll_longitudes: Whether and how much to randomly roll longitudes by.\n", " Defaults to 0.\n", " positional_encoding: The type of positional encoding to use.\n", " Defaults to \"absolute\".\n", "\n", " Raises:\n", " ValueError: If lead time is not integer multiple of input time\n", " \"\"\"\n", "\n", " self._target_lead = lead_time\n", "\n", " if isinstance(input_time, (int, float)):\n", " self.timedelta_input = pd.to_timedelta(-input_time, unit=\"h\")\n", " else:\n", " self.timedelta_input = -input_time\n", "\n", " lead_times = [self.timedelta_input / pd.to_timedelta(1, unit=\"h\")]\n", "\n", 
" super().__init__(\n", " time_range,\n", " lead_times,\n", " [input_time],\n", " data_path_surface,\n", " data_path_vertical,\n", " climatology_path_surface,\n", " climatology_path_vertical,\n", " surface_vars,\n", " static_surface_vars,\n", " vertical_vars,\n", " levels,\n", " roll_longitudes,\n", " positional_encoding,\n", " )\n", "\n", " nstep_float = (\n", " pd.to_timedelta(self._target_lead, unit=\"h\") / self.timedelta_input\n", " )\n", "\n", " if abs(nstep_float % 1) > 1e-5:\n", " raise ValueError(\"Lead time is not an integer multiple of input time\")\n", "\n", " self.nsteps = round(nstep_float)\n", "\n", " @ft.cached_property\n", " def samples(self) -> list[tuple[pd.Timestamp, int, int]]:\n", " \"\"\"Generates list of all valid samples.\n", "\n", " Returns:\n", " List of tuples (timestamp, input time, lead time, nsteps).\n", " \"\"\"\n", " valid_samples = []\n", "\n", " for timestamp in sorted(self.valid_timestamps):\n", " timestamp_samples = []\n", " for lt in self.lead_times:\n", " spec = RolloutSpec.get(timestamp, lt, self.nsteps)\n", "\n", " if self._data_available(spec):\n", " timestamp_samples.append(\n", " (timestamp, self.input_times[0], lt, self.nsteps)\n", " )\n", "\n", " if timestamp_samples:\n", " valid_samples.append(timestamp_samples)\n", "\n", " return valid_samples\n", "\n", " def get_data_from_rollout_spec(\n", " self, spec: RolloutSpec\n", " ) -> dict[str, Tensor | int | float]:\n", " \"\"\"Loads and assembles sample data given a RolloutSpec object.\n", "\n", " Args:\n", " spec (RolloutSpec): Full details regarding the data to be loaded\n", " Returns:\n", " dict: Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate', 'lead_time',\n", " 'input_time'. For each, the value is as follows::\n", "\n", " {\n", " 'sur_static': Torch tensor of shape [parameter, lat, lon]. 
For\n", " each pixel (lat, lon), the first 7 dimensions index sin(lat),\n", " cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).\n", " Where doy is the day of the year [1, 366] and hod the hour of\n", " the day [0, 23].\n", " 'sur_vals': Torch tensor of shape [parameter, time, lat, lon].\n", " 'sur_tars': Torch tensor of shape [parameter, time, lat, lon].\n", " 'ulv_vals': Torch tensor of shape\n", " [parameter, level, time, lat, lon].\n", " 'ulv_tars': Torch tensor of shape\n", " [nsteps, parameter, level, time, lat, lon].\n", " 'sur_climate': Torch tensor of shape\n", " [nsteps, parameter, lat, lon].\n", " 'ulv_climate': Torch tensor of shape\n", " [nsteps, parameter, level, lat, lon].\n", " 'lead_time': Float.\n", " 'input_time': Float.\n", " }\n", "\n", " \"\"\"\n", "\n", " # We assemble the unique timestamps for which we need data.\n", " vals_required = {*spec.times}\n", " stat_required = {*spec.stat_times}\n", "\n", " # We assemble the unique data files from which we need value data\n", " vals_file_map = defaultdict(list)\n", " for t in vals_required:\n", " data_files = (\n", " self.data_file_surface(t),\n", " self.data_file_vertical(t),\n", " )\n", " vals_file_map[data_files].append(t)\n", "\n", " # We assemble the unique data files from which we need static data\n", " stat_file_map = defaultdict(list)\n", " for t in stat_required:\n", " data_files = (\n", " self.data_file_surface(t),\n", " self.data_file_vertical(t),\n", " )\n", " stat_file_map[data_files].append(t)\n", "\n", " # Load the value data\n", " data = {}\n", " for data_files, times in vals_file_map.items():\n", " for time in times:\n", " data[time] = self._read_data(data_files, time)\n", "\n", " # Load the static data\n", " stat = {}\n", " for data_files, times in stat_file_map.items():\n", " for time in times:\n", " hod, doy = time.hour, time.dayofyear\n", " stat[time] = self._read_static_data(data_files[0], hod, doy)\n", "\n", " # Combine times\n", " sample_data = {}\n", "\n", " 
input_upl = np.stack([data[t][\"vert\"] for t in spec.inputs], axis=2)\n", " sample_data[\"ulv_vals\"] = input_upl\n", "\n", " target_upl = np.stack([data[t][\"vert\"] for t in spec.targets], axis=2)\n", " sample_data[\"ulv_tars\"] = target_upl\n", "\n", " input_sur = np.stack([data[t][\"surf\"] for t in spec.inputs], axis=1)\n", " sample_data[\"sur_vals\"] = input_sur\n", "\n", " target_sur = np.stack([data[t][\"surf\"] for t in spec.targets], axis=1)\n", " sample_data[\"sur_tars\"] = target_sur\n", "\n", " # Load the static data\n", " static = np.stack([stat[t] for t in spec.stat_times], axis=0)\n", " sample_data[\"sur_static\"] = static\n", "\n", " # If required load the climate data\n", " if self._require_clim:\n", " clim_data = {}\n", " for ci in spec.climatology_info:\n", " ci_year, ci_hour = ci\n", "\n", " surf_file = self.data_file_surface_climate(\n", " dayofyear=ci_year,\n", " hourofday=ci_hour,\n", " )\n", "\n", " vert_file = self.data_file_vertical_climate(\n", " dayofyear=ci_year,\n", " hourofday=ci_hour,\n", " )\n", "\n", " clim_data[ci] = self._read_climate((surf_file, vert_file))\n", "\n", " clim_surf = [clim_data[ci][\"surf\"] for ci in spec.climatology_info]\n", " sample_data[\"sur_climate\"] = np.stack(clim_surf, axis=0)\n", "\n", " clim_vert = [clim_data[ci][\"vert\"] for ci in spec.climatology_info]\n", " sample_data[\"ulv_climate\"] = np.stack(clim_vert, axis=0)\n", "\n", " # Move the data from numpy to torch\n", " sample_data = self._to_torch(sample_data, dtype=self.dtype)\n", "\n", " # Optionally roll\n", " if len(self._roll_longitudes) > 0:\n", " roll_by = random.choice(self._roll_longitudes)\n", " sample_data = self._lat_roll(sample_data, roll_by)\n", "\n", " # Now that we have rolled, we can add the lead and input times\n", " lt = torch.tensor([spec.lead_time] * self.nsteps).to(self.dtype)\n", " sample_data[\"lead_time\"] = lt\n", " sample_data[\"input_time\"] = spec.input_time\n", "\n", " return sample_data\n", "\n", " def get_data(\n", " self, 
timestamp: pd.Timestamp, *args, **kwargs\n", " ) -> dict[str, Tensor | int]:\n", " \"\"\"Loads data based on timestamp and lead time.\n", "\n", " Args:\n", " timestamp: Timestamp.\n", " Returns:\n", " Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", " 'lead_time', 'input_time'\n", " \"\"\"\n", " rollout_spec = RolloutSpec.get(\n", " timestamp, self.lead_times[0], self.nsteps\n", " )\n", " sample_data = self.get_data_from_rollout_spec(rollout_spec)\n", " return sample_data\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# from PrithviWxC.dataloaders.merra2_rollout import Merra2RolloutDataset\n", "\n", "dataset = Merra2RolloutDataset(\n", " time_range=time_range,\n", " lead_time=lead_time,\n", " input_time=input_time,\n", " data_path_surface=surf_dir,\n", " data_path_vertical=vert_dir,\n", " climatology_path_surface=surf_clim_dir,\n", " climatology_path_vertical=vert_clim_dir,\n", " surface_vars=surface_vars,\n", " static_surface_vars=static_surface_vars,\n", " vertical_vars=vertical_vars,\n", " levels=levels,\n", " positional_encoding=positional_encoding,\n", ")\n", "assert len(dataset) > 0, \"There doesn't seem to be any valid data.\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model\n", "### Scalers and other hyperparameters\n", "Again, this setup is similar to before." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# from PrithviWxC.dataloaders.merra2 import (\n", "# input_scalers,\n", "# output_scalers,\n", "# static_input_scalers,\n", "# )\n", "\n", "surf_in_scal_path = Path(\"./climatology/musigma_surface.nc\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=f\"climatology/{surf_in_scal_path.name}\",\n", " local_dir=\".\",\n", ")\n", "\n", "vert_in_scal_path = Path(\"./climatology/musigma_vertical.nc\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=f\"climatology/{vert_in_scal_path.name}\",\n", " local_dir=\".\",\n", ")\n", "\n", "surf_out_scal_path = Path(\"./climatology/anomaly_variance_surface.nc\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=f\"climatology/{surf_out_scal_path.name}\",\n", " local_dir=\".\",\n", ")\n", "\n", "vert_out_scal_path = Path(\"./climatology/anomaly_variance_vertical.nc\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=f\"climatology/{vert_out_scal_path.name}\",\n", " local_dir=\".\",\n", ")\n", "\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.rollout.2300m.v1\",\n", " filename=\"config.yaml\",\n", " local_dir=\".\",\n", ")\n", "\n", "in_mu, in_sig = input_scalers(\n", " surface_vars,\n", " vertical_vars,\n", " levels,\n", " surf_in_scal_path,\n", " vert_in_scal_path,\n", ")\n", "\n", "output_sig = output_scalers(\n", " surface_vars,\n", " vertical_vars,\n", " levels,\n", " surf_out_scal_path,\n", " vert_out_scal_path,\n", ")\n", "\n", "static_mu, static_sig = static_input_scalers(\n", " surf_in_scal_path,\n", " static_surface_vars,\n", ")\n", "\n", "residual = \"none\"\n", "masking_mode = \"local\"\n", "decoder_shifting = True\n", "masking_ratio = 0.99" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model init\n", "We can now build and load the 
pretrained weights. Note that you should use the\n", "rollout version of the weights." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'weights\\\\prithvi.wxc.rollout.2300m.v1.pt'" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weights_path = Path(\"./weights/prithvi.wxc.rollout.2300m.v1.pt\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.rollout.2300m.v1\",\n", " filename=weights_path.name,\n", " local_dir=\"./weights\",\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from functools import cached_property\n", "from importlib.metadata import version\n", "\n", "from torch import Tensor\n", "from torch.utils.checkpoint import checkpoint\n", "\n", "if version(\"torch\") > \"2.3.0\":\n", " from torch.nn.attention import SDPBackend, sdpa_kernel\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "# DropPath code is straight from timm\n", "# (https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py)\n", "def drop_path(\n", " x: Tensor,\n", " drop_prob: float = 0.0,\n", " training: bool = False,\n", " scale_by_keep: bool = True,\n", ") -> Tensor:\n", " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of\n", " residual blocks). 
Taken from timm.\n", "\n", " Args:\n", " x (Tensor): Input tensor.\n", " drop_prob (float): Probability of dropping `x`, defaults to 0.\n", " training (bool): Whether the model is in training or eval mode,\n", " defaults to False.\n", " scale_by_keep (bool): Whether the output should be scaled by\n", " (`1 - drop_prob`), defaults to True.\n", " Returns:\n", " Tensor: Tensor whose samples may have been randomly dropped with\n", " probability `drop_prob`\n", " \"\"\"\n", " if drop_prob == 0.0 or not training:\n", " return x\n", " keep_prob = 1 - drop_prob\n", " shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n", " random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n", " if keep_prob > 0.0 and scale_by_keep:\n", " random_tensor.div_(keep_prob)\n", " return x * random_tensor\n", "\n", "\n", "class DropPath(nn.Module):\n", " \"\"\"\n", " Drop paths (Stochastic Depth) per sample (when applied in main path of\n", " residual blocks).\n", " \"\"\"\n", "\n", " def __init__(\n", " self, drop_prob: float | None = None, scale_by_keep: bool = True\n", " ) -> None:\n", " super(DropPath, self).__init__()\n", " self.drop_prob = drop_prob\n", " self.scale_by_keep = scale_by_keep\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"Runs drop path on input tensor\n", "\n", " Args:\n", " x: input\n", "\n", " Returns:\n", " tensor: output after drop_path\n", " \"\"\"\n", " return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n", "\n", "\n", "class Mlp(nn.Module):\n", " \"\"\"\n", " Multi layer perceptron.\n", " \"\"\"\n", "\n", " def __init__(\n", " self, features: int, hidden_features: int, dropout: float = 0.0\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " features: Input/output dimension.\n", " hidden_features: Hidden dimension.\n", " dropout: Dropout.\n", " \"\"\"\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(features, hidden_features),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_features, features),\n", " 
nn.Dropout(dropout),\n", " )\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " x (Tensor): Tensor of shape [..., channel]\n", " Returns:\n", " Tensor: Tensor of same shape as x.\n", " \"\"\"\n", " return self.net(x)\n", "\n", "\n", "class LayerNormPassThrough(nn.LayerNorm):\n", " \"\"\"Normalising layer that allows the attention mask to be passed through\"\"\"\n", "\n", " def __init__(self, *args, **kwargs):\n", " super().__init__(*args, **kwargs)\n", "\n", " def forward(\n", " self, d: tuple[Tensor, Tensor | None]\n", " ) -> tuple[Tensor, Tensor | None]:\n", " \"\"\"Forward function\n", "\n", " Args:\n", " d (tuple): tuple of the data tensor and the attention mask\n", " Returns:\n", " output (Tensor): normalised output data\n", " attn_mask (Tensor): the attention mask that was passed in\n", " \"\"\"\n", " input, attn_mask = d\n", " output = F.layer_norm(\n", " input, self.normalized_shape, self.weight, self.bias, self.eps\n", " )\n", " return output, attn_mask\n", "\n", "\n", "class MultiheadAttention(nn.Module):\n", " \"\"\"Multihead attention layer for inputs of shape\n", " [..., sequence, features].\n", " \"\"\"\n", "\n", " def __init__(self, features: int, n_heads: int, dropout: float) -> None:\n", " \"\"\"\n", " Args:\n", " features: Number of features for inputs to the layer.\n", " n_heads: Number of attention heads. Should be a factor of features.\n", " (I.e.
the layer uses features // n_heads.)\n", " dropout: Dropout.\n", " \"\"\" # noqa: E501\n", " super().__init__()\n", "\n", " if (features % n_heads) != 0:\n", " raise ValueError(\n", " f\"Features '{features}' is not divisible by heads '{n_heads}'.\"\n", " )\n", "\n", " self.features = features\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", "\n", " self.qkv_layer = torch.nn.Linear(features, features * 3, bias=False)\n", " self.w_layer = torch.nn.Linear(features, features, bias=False)\n", "\n", " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", " \"\"\"\n", " Args:\n", " d (tuple): tuple containing Tensor of shape [..., sequence, features] and the attention mask\n", " Returns:\n", " Tensor: Tensor of shape [..., sequence, features]\n", " \"\"\" # noqa: E501\n", " x, attn_mask = d\n", "\n", " if not x.shape[-1] == self.features:\n", " raise ValueError(\n", " f\"Expecting tensor with last dimension size {self.features}.\"\n", " )\n", "\n", " passenger_dims = x.shape[:-2]\n", " B = passenger_dims.numel()\n", " S = x.shape[-2]\n", " C = x.shape[-1]\n", " x = x.reshape(B, S, C)\n", "\n", " # x [B, S, C]\n", " # q, k, v [B, H, S, C/H]\n", " q, k, v = (\n", " self.qkv_layer(x)\n", " .view(B, S, self.n_heads, 3 * (C // self.n_heads))\n", " .transpose(1, 2)\n", " .chunk(chunks=3, dim=3)\n", " )\n", "\n", " # Let us enforce either flash (A100+) or memory efficient attention.\n", " if version(\"torch\") > \"2.3.0\":\n", " with sdpa_kernel(\n", " [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]\n", " ):\n", " # x [B, H, S, C//H]\n", " x = F.scaled_dot_product_attention(\n", " q, k, v, attn_mask=attn_mask, dropout_p=self.dropout\n", " )\n", " else:\n", " with torch.backends.cuda.sdp_kernel(\n", " enable_flash=True, enable_math=False, enable_mem_efficient=True\n", " ):\n", " # x [B, H, S, C//H]\n", " x = F.scaled_dot_product_attention(\n", " q, k, v, dropout_p=self.dropout\n", " )\n", "\n", " # x [B, S, C]\n", " x = x.transpose(1, 
2).view(B, S, C)\n", "\n", " # x [B, S, C]\n", " x = self.w_layer(x)\n", "\n", " # Back to input shape\n", " x = x.view(*passenger_dims, S, self.features)\n", " return x\n", "\n", "\n", "class Transformer(nn.Module):\n", " \"\"\"\n", " Transformer for inputs of shape [..., S, features].\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " features: int,\n", " mlp_multiplier: int,\n", " n_heads: int,\n", " dropout: float,\n", " drop_path: float,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " features: Number of features for inputs to the layer.\n", " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", " n_heads: Number of attention heads. Should be a factor of features.\n", " (I.e. the layer uses features // n_heads.)\n", " dropout: Dropout.\n", " drop_path: DropPath.\n", " \"\"\"\n", " super().__init__()\n", "\n", " self.features = features\n", " self.mlp_multiplier = mlp_multiplier\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", " self.drop_path = (\n", " DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n", " )\n", "\n", " self.attention = nn.Sequential(\n", " LayerNormPassThrough(features),\n", " MultiheadAttention(features, n_heads, dropout),\n", " )\n", "\n", " self.ff = nn.Sequential(\n", " nn.LayerNorm(features),\n", " Mlp(\n", " features=features,\n", " hidden_features=features * mlp_multiplier,\n", " dropout=dropout,\n", " ),\n", " )\n", "\n", " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", " \"\"\"\n", " Args:\n", " d: Tuple of a tensor of shape [..., sequence, features] and\n", " the attention mask (or None).\n", " Returns:\n", " Tensor: Tensor of shape [..., sequence, features]\n", " \"\"\"\n", " x, attn_mask = d\n", " if not x.shape[-1] == self.features:\n", " raise ValueError(\n", " f\"Expecting tensor with last dimension size {self.features}.\"\n", " )\n", "\n", " attention_x = self.attention(d)\n", "\n", " x = x + self.drop_path(attention_x)\n", " x = x + self.drop_path(self.ff(x))\n", "\n", " return x\n", "\n", "\n", "class
_Shift(nn.Module):\n", " \"\"\"Private base class for the shifter. This allows some behaviour to be\n", " easily handled when the shifter isn't used.\n", " \"\"\"\n", "\n", " def __init__(self):\n", " super().__init__()\n", "\n", " self._shifted = False\n", "\n", " @torch.no_grad()\n", " def reset(self) -> None:\n", " \"\"\"\n", " Resets the bool tracking whether the data is shifted\n", " \"\"\"\n", " self._shifted: bool = False\n", "\n", " def forward(self, data: Tensor) -> tuple[Tensor, dict[bool, None]]:\n", " return data, {True: None, False: None}\n", "\n", "\n", "class SWINShift(_Shift):\n", " \"\"\"\n", " Handles the shifting of patches similar to how SWIN works. However if we\n", " shift the latitudes then the poles will wrap and potentially that might be\n", " problematic. The position tokens should handle it but masking is safer.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " mu_shape: tuple[int, int],\n", " global_shape: tuple[int, int],\n", " local_shape: tuple[int, int],\n", " patch_shape: tuple[int, int],\n", " n_context_tokens: int = 2,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " mu_shape: the shape of the masking units\n", " global_shape: number of global patches in lat and lon\n", " local_shape: size of the local patches\n", " patch_shape: patch size\n", " n_context_tokens: number of additional context tokens at start of\n", " _each_ local sequence\n", " \"\"\"\n", " super().__init__()\n", "\n", " self._mu_shape = ms = mu_shape\n", " self._g_shape = gs = global_shape\n", " self._l_shape = ls = local_shape\n", " self._p_shape = ps = patch_shape\n", " self._lat_patch = (gs[0], ls[0], gs[1], ls[1])\n", " self._n_context_tokens = n_context_tokens\n", "\n", " self._g_shift_to = tuple(\n", " int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", " )\n", " self._g_shift_from = tuple(\n", " -int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", " )\n", "\n", " # Define the attention masks for the shifted MaxViT.\n", " nglobal =
global_shape[0] * global_shape[1]\n", " nlocal = (\n", " local_shape[0] * local_shape[1] + self._n_context_tokens\n", " ) # plus the context tokens (e.g. lead time)\n", "\n", " lm = torch.ones((nglobal, 1, nlocal, nlocal), dtype=bool)\n", " mwidth = int(0.5 * local_shape[1]) * local_shape[0]\n", " lm[\n", " : gs[1],\n", " :,\n", " self._n_context_tokens : mwidth + self._n_context_tokens,\n", " self._n_context_tokens : mwidth + self._n_context_tokens,\n", " ] = False\n", " self.register_buffer(\"local_mask\", lm)\n", "\n", " gm = torch.ones((nlocal, 1, nglobal, nglobal), dtype=bool)\n", " gm[: int(0.5 * ls[1]) * ls[0], :, : gs[1], : gs[1]] = False\n", " self.register_buffer(\"global_mask\", gm)\n", "\n", " def _to_grid_global(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shuffle and reshape the data from the global/local setting back to the\n", " lat/lon grid setting\n", " Args:\n", " x: the data tensor to be shuffled.\n", " Returns:\n", " x: data in the lat/lon setting\n", " \"\"\"\n", " nbatch, *other = x.shape\n", "\n", " y1 = x.view(nbatch, *self._g_shape, *self._l_shape, -1)\n", " y2 = y1.permute(0, 5, 1, 3, 2, 4).contiguous()\n", "\n", " s = y2.shape\n", " return y2.view((nbatch, -1, s[2] * s[3], s[4] * s[5]))\n", "\n", " def _to_grid_local(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shuffle and reshape the data from the local/global setting to the\n", " lat/lon grid setting\n", " Args:\n", " x: the data tensor to be shuffled.\n", " Returns:\n", " x: data in the lat/lon setting.\n", " \"\"\"\n", " x = x.transpose(2, 1).contiguous()\n", " return self._to_grid_global(x)\n", "\n", " def _from_grid_global(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shuffle and reshape the data from the lat/lon grid to the global/local\n", " setting\n", " Args:\n", " x: the data tensor to be shuffled.\n", " Returns:\n", " x: data in the global/local setting\n", " \"\"\"\n", " nbatch, *other = x.shape\n", "\n", " z1 = x.view(nbatch, -1, *self._lat_patch)\n", " z2 = z1.permute(0, 2, 4, 3, 5,
1).contiguous()\n", "\n", " s = z2.shape\n", " return z2.view(nbatch, s[1] * s[2], s[3] * s[4], -1)\n", "\n", " def _from_grid_local(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shuffle and reshape the data from the lat/lon grid to the local/global\n", " setting\n", " Args:\n", " x: the data tensor to be shuffled.\n", " Returns:\n", " x: data in the local/global setting\n", " \"\"\"\n", " x = self._from_grid_global(x)\n", " return x.transpose(2, 1).contiguous()\n", "\n", " def _shift(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shifts data in the gridded lat/lon setting by half the mask unit shape\n", " Args:\n", " x: data to be shifted\n", " Returns:\n", " x: either the shifted or unshifted data\n", " \"\"\"\n", " shift = self._g_shift_from if self._shifted else self._g_shift_to\n", " x_shifted = torch.roll(x, shift, (-2, -1))\n", "\n", " self._shifted = not self._shifted\n", " return x_shifted\n", "\n", " def _sep_lt(self, x: Tensor) -> tuple[Tensor, Tensor]:\n", " \"\"\"\n", " Separate off the leadtime from the local patches\n", " Args:\n", " x: data to have leadtime removed from\n", " Returns:\n", " lt: leadtime\n", " x: data without the lead time in the local patch\n", " \"\"\"\n", " lt_it = x[:, : self._n_context_tokens, :, :]\n", " x_stripped = x[:, self._n_context_tokens :, :, :]\n", "\n", " return lt_it, x_stripped\n", "\n", " def forward(self, data: Tensor) -> tuple[Tensor, Tensor]:\n", " \"\"\"Shift or unshift the data depending on whether the data is\n", " already shifted, as defined by self._shifted.\n", "\n", " Args:\n", " data: data to be shifted\n", " Returns:\n", " Tensor: shifted data Tensor\n", " \"\"\"\n", " lt, x = self._sep_lt(data)\n", "\n", " x_grid = self._to_grid_local(x)\n", " x_shifted = self._shift(x_grid)\n", " x_patched = self._from_grid_local(x_shifted)\n", "\n", " # Mask has to be repeated based on batch size\n", " n_batch = x_grid.shape[0]\n", " local_rep = [n_batch] + [1] * (self.local_mask.ndim - 1)\n", " global_rep =
[n_batch] + [1] * (self.global_mask.ndim - 1)\n", "\n", " if self._shifted:\n", " attn_mask = {\n", " True: self.local_mask.repeat(local_rep),\n", " False: self.global_mask.repeat(global_rep),\n", " }\n", " else:\n", " attn_mask = {True: None, False: None}\n", "\n", " return torch.cat((lt, x_patched), axis=1), attn_mask\n", "\n", "\n", "class LocalGlobalLocalBlock(nn.Module):\n", " \"\"\"\n", " Applies alternating block and grid attention. Given a parameter n_blocks,\n", " the entire module contains 2*n_blocks+1 transformer blocks. The first,\n", " third, ..., last apply local (block) attention. The second, fourth, ...\n", " global (grid) attention.\n", "\n", " This is heavily inspired by\n", " Tu et al. \"MaxViT: Multi-Axis Vision Transformer\"\n", " (https://arxiv.org/abs/2204.01697).\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " features: int,\n", " mlp_multiplier: int,\n", " n_heads: int,\n", " dropout: float,\n", " n_blocks: int,\n", " drop_path: float,\n", " shifter: nn.Module | None = None,\n", " checkpoint: list[int] | None = None,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " features: Number of features for inputs to the layer.\n", " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", " n_heads: Number of attention heads. Should be a factor of features.\n", " (I.e. the layer uses features // n_heads.)\n", " dropout: Dropout.\n", " drop_path: DropPath.\n", " n_blocks: Number of local-global transformer pairs.\n", " \"\"\"\n", " super().__init__()\n", "\n", " self.features = features\n", " self.mlp_multiplier = mlp_multiplier\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", " self.drop_path = drop_path\n", " self.n_blocks = n_blocks\n", " self._checkpoint = checkpoint or []\n", "\n", " if not all(0 <= c < 2 * n_blocks + 1 for c in self._checkpoint):\n", " raise ValueError(\n", " \"Checkpoints should be 0 <= i < 2*n_blocks+1. 
\"\n", " f\"{self._checkpoint=}.\"\n", " )\n", "\n", " self.transformers = nn.ModuleList(\n", " [\n", " Transformer(\n", " features=features,\n", " mlp_multiplier=mlp_multiplier,\n", " n_heads=n_heads,\n", " dropout=dropout,\n", " drop_path=drop_path,\n", " )\n", " for _ in range(2 * n_blocks + 1)\n", " ]\n", " )\n", "\n", " self.evaluator = [\n", " self._checkpoint_wrapper\n", " if i in self._checkpoint\n", " else lambda m, x: m(x)\n", " for i, _ in enumerate(self.transformers)\n", " ]\n", "\n", " self.shifter = shifter or _Shift()\n", "\n", " @staticmethod\n", " def _checkpoint_wrapper(\n", " model: nn.Module, data: tuple[Tensor, Tensor | None]\n", " ) -> Tensor:\n", " return checkpoint(model, data, use_reentrant=False)\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " x: Tensor of shape::\n", "\n", " [batch, global_sequence, local_sequence, features]\n", "\n", " Returns:\n", " Tensor: Tensor of shape::\n", "\n", " [batch, global_sequence, local_sequence, features]\n", " \"\"\"\n", " if x.shape[-1] != self.features:\n", " raise ValueError(\n", " f\"Expecting tensor with last dimension size {self.features}.\"\n", " )\n", " if x.ndim != 4:\n", " raise ValueError(\n", " f\"Expecting tensor with exactly four dimensions. 
{x.shape=}.\"\n", " )\n", "\n", " self.shifter.reset()\n", " local: bool = True\n", " attn_mask = {True: None, False: None}\n", "\n", " transformer_iter = zip(self.evaluator, self.transformers, strict=False)\n", "\n", " # First local block\n", " evaluator, transformer = next(transformer_iter)\n", " x = evaluator(transformer, (x, attn_mask[local]))\n", "\n", " for evaluator, transformer in transformer_iter:\n", " local = not local\n", " # We are making exactly 2*n_blocks transposes.\n", " # So the output has the same shape as input.\n", " x = x.transpose(1, 2)\n", "\n", " x = evaluator(transformer, (x, attn_mask[local]))\n", "\n", " if not local:\n", " x, attn_mask = self.shifter(x)\n", "\n", " return x\n", "\n", "\n", "class PatchEmbed(nn.Module):\n", " \"\"\"\n", " Patch embedding via 2D convolution.\n", " \"\"\"\n", "\n", " def __init__(\n", " self, patch_size: int | tuple[int, ...], channels: int, embed_dim: int\n", " ):\n", " super().__init__()\n", "\n", " self.patch_size = patch_size\n", " self.channels = channels\n", " self.embed_dim = embed_dim\n", "\n", " self.proj = nn.Conv2d(\n", " channels,\n", " embed_dim,\n", " kernel_size=patch_size,\n", " stride=patch_size,\n", " bias=True,\n", " )\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " x: Tensor of shape [batch, channels, lat, lon].\n", " Returns:\n", " Tensor: Tensor with shape\n", " [batch, embed_dim, lat//patch_size, lon//patch_size]\n", " \"\"\"\n", "\n", " H, W = x.shape[-2:]\n", "\n", " if W % self.patch_size[1] != 0:\n", " raise ValueError(\n", " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", " f\" with patch size {self.patch_size}. (Dimensions are BSCHW.)\"\n", " )\n", " if H % self.patch_size[0] != 0:\n", " raise ValueError(\n", " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", " f\" with patch size {self.patch_size}.
(Dimensions are BSCHW.)\"\n", " )\n", "\n", " x = self.proj(x)\n", "\n", " return x\n", "\n", "\n", "class PrithviWxCEncoderDecoder(nn.Module):\n", " \"\"\"\n", " Hiera-MaxViT encoder/decoder code.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " embed_dim: int,\n", " n_blocks: int,\n", " mlp_multiplier: float,\n", " n_heads: int,\n", " dropout: float,\n", " drop_path: float,\n", " shifter: nn.Module | None = None,\n", " transformer_cp: list[int] | None = None,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " embed_dim: Embedding dimension\n", " n_blocks: Number of local-global transformer pairs.\n", " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", " networks.\n", " n_heads: Number of attention heads.\n", " dropout: Dropout.\n", " drop_path: DropPath.\n", " \"\"\"\n", " super().__init__()\n", "\n", " self.embed_dim = embed_dim\n", " self.n_blocks = n_blocks\n", " self.mlp_multiplier = mlp_multiplier\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", " self._transformer_cp = transformer_cp\n", "\n", " self.lgl_block = LocalGlobalLocalBlock(\n", " features=embed_dim,\n", " mlp_multiplier=mlp_multiplier,\n", " n_heads=n_heads,\n", " dropout=dropout,\n", " drop_path=drop_path,\n", " n_blocks=n_blocks,\n", " shifter=shifter,\n", " checkpoint=transformer_cp,\n", " )\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " Args:\n", " x: Tensor of shape\n", " [batch, global sequence, local sequence, embed_dim]\n", " Returns:\n", " Tensor of shape\n", " [batch, mask_unit_sequence, local_sequence, embed_dim].\n", " Identical in shape to the input x.\n", " \"\"\"\n", "\n", " x = self.lgl_block(x)\n", "\n", " return x\n", "\n", "\n", "class PrithviWxC(nn.Module):\n", " \"\"\"Encoder-decoder fusing Hiera with MaxViT. See\n", " - Ryali et al. \"Hiera: A Hierarchical Vision Transformer without the\n", " Bells-and-Whistles\" (https://arxiv.org/abs/2306.00989)\n", " - Tu et al. 
\"MaxViT: Multi-Axis Vision Transformer\"\n", " (https://arxiv.org/abs/2204.01697)\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " in_channels: int,\n", " input_size_time: int,\n", " in_channels_static: int,\n", " input_scalers_mu: Tensor,\n", " input_scalers_sigma: Tensor,\n", " input_scalers_epsilon: float,\n", " static_input_scalers_mu: Tensor,\n", " static_input_scalers_sigma: Tensor,\n", " static_input_scalers_epsilon: float,\n", " output_scalers: Tensor,\n", " n_lats_px: int,\n", " n_lons_px: int,\n", " patch_size_px: tuple[int],\n", " mask_unit_size_px: tuple[int],\n", " mask_ratio_inputs: float,\n", " embed_dim: int,\n", " n_blocks_encoder: int,\n", " n_blocks_decoder: int,\n", " mlp_multiplier: float,\n", " n_heads: int,\n", " dropout: float,\n", " drop_path: float,\n", " parameter_dropout: float,\n", " residual: str,\n", " masking_mode: str,\n", " positional_encoding: str,\n", " decoder_shifting: bool = False,\n", " checkpoint_encoder: list[int] | None = None,\n", " checkpoint_decoder: list[int] | None = None,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " in_channels: Number of input channels.\n", " input_size_time: Number of timestamps in input.\n", " in_channels_static: Number of input channels for static data.\n", " input_scalers_mu: Tensor of size (in_channels,). Used to rescale\n", " input.\n", " input_scalers_sigma: Tensor of size (in_channels,). Used to rescale\n", " input.\n", " input_scalers_epsilon: Float. Used to rescale input.\n", " static_input_scalers_mu: Tensor of size (in_channels_static). Used\n", " to rescale static inputs.\n", " static_input_scalers_sigma: Tensor of size (in_channels_static).\n", " Used to rescale static inputs.\n", " static_input_scalers_epsilon: Float. Used to rescale static inputs.\n", " output_scalers: Tensor of shape (in_channels,). Used to rescale\n", " output.\n", " n_lats_px: Total latitudes in data. In pixels.\n", " n_lons_px: Total longitudes in data. 
In pixels.\n", " patch_size_px: Patch size for tokenization. In pixels lat/lon.\n", " mask_unit_size_px: Size of each mask unit. In pixels lat/lon.\n", " mask_ratio_inputs: Masking ratio for inputs. 0 to 1.\n", " embed_dim: Embedding dimension\n", " n_blocks_encoder: Number of local-global transformer pairs in\n", " encoder.\n", " n_blocks_decoder: Number of local-global transformer pairs in\n", " decoder.\n", " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", " networks.\n", " n_heads: Number of attention heads.\n", " dropout: Dropout.\n", " drop_path: DropPath.\n", " parameter_dropout: Dropout applied to parameters.\n", " residual: Indicates whether and how model should work as residual\n", " model. Accepted values are 'climate', 'temporal' and 'none'\n", " positional_encoding: possible values are\n", " ['absolute' (default), 'fourier'].\n", " 'absolute' lat lon encoded in 3 dimensions using sine and\n", " cosine\n", " 'fourier' lat/lon to be encoded using various frequencies\n", " masking_mode: String ['local', 'global', 'both'] that controls the\n", " type of masking used.\n", " checkpoint_encoder: List of integers controlling if gradient\n", " checkpointing is used on encoder.\n", " Format: [] for no gradient checkpointing. 
[3, 7] for\n", " checkpointing after 4th and 8th layer etc.\n", " checkpoint_decoder: List of integers controlling if gradient\n", " checkpointing is used on decoder.\n", " Format: See `checkpoint_encoder`.\n", " decoder_shifting: Whether to use swin shifting in the decoder.\n", " \"\"\"\n", " super().__init__()\n", "\n", " self.in_channels = in_channels\n", " self.input_size_time = input_size_time\n", " self.in_channels_static = in_channels_static\n", " self.n_lats_px = n_lats_px\n", " self.n_lons_px = n_lons_px\n", " self.patch_size_px = patch_size_px\n", " self.mask_unit_size_px = mask_unit_size_px\n", " self.mask_ratio_inputs = mask_ratio_inputs\n", " self.embed_dim = embed_dim\n", " self.n_blocks_encoder = n_blocks_encoder\n", " self.n_blocks_decoder = n_blocks_decoder\n", " self.mlp_multiplier = mlp_multiplier\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", " self.drop_path = drop_path\n", " self.residual = residual\n", " self._decoder_shift = decoder_shifting\n", " self.positional_encoding = positional_encoding\n", " self._checkpoint_encoder = checkpoint_encoder\n", " self._checkpoint_decoder = checkpoint_decoder\n", "\n", " assert self.n_lats_px % self.mask_unit_size_px[0] == 0\n", " assert self.n_lons_px % self.mask_unit_size_px[1] == 0\n", " assert self.mask_unit_size_px[0] % self.patch_size_px[0] == 0\n", " assert self.mask_unit_size_px[1] % self.patch_size_px[1] == 0\n", "\n", " if self.patch_size_px[0] != self.patch_size_px[1]:\n", " raise NotImplementedError(\n", " \"Current pixel shuffle requires symmetric patches.\"\n", " )\n", "\n", " self.local_shape_mu = (\n", " self.mask_unit_size_px[0] // self.patch_size_px[0],\n", " self.mask_unit_size_px[1] // self.patch_size_px[1],\n", " )\n", " self.global_shape_mu = (\n", " self.n_lats_px // self.mask_unit_size_px[0],\n", " self.n_lons_px // self.mask_unit_size_px[1],\n", " )\n", "\n", " assert input_scalers_mu.shape ==
(in_channels,)\n", " assert input_scalers_sigma.shape == (in_channels,)\n", " assert output_scalers.shape == (in_channels,)\n", "\n", " if self.positional_encoding != \"fourier\":\n", " assert static_input_scalers_mu.shape == (in_channels_static,)\n", " assert static_input_scalers_sigma.shape == (in_channels_static,)\n", "\n", " # Input shape [batch, time, parameter, lat, lon]\n", " self.input_scalers_epsilon = input_scalers_epsilon\n", " self.register_buffer(\n", " \"input_scalers_mu\", input_scalers_mu.reshape(1, 1, -1, 1, 1)\n", " )\n", " self.register_buffer(\n", " \"input_scalers_sigma\", input_scalers_sigma.reshape(1, 1, -1, 1, 1)\n", " )\n", "\n", " # Static inputs shape [batch, parameter, lat, lon]\n", " self.static_input_scalers_epsilon = static_input_scalers_epsilon\n", " self.register_buffer(\n", " \"static_input_scalers_mu\",\n", " static_input_scalers_mu.reshape(1, -1, 1, 1),\n", " )\n", " self.register_buffer(\n", " \"static_input_scalers_sigma\",\n", " static_input_scalers_sigma.reshape(1, -1, 1, 1),\n", " )\n", "\n", " # Output shape [batch, parameter, lat, lon]\n", " self.register_buffer(\n", " \"output_scalers\", output_scalers.reshape(1, -1, 1, 1)\n", " )\n", "\n", " self.parameter_dropout = nn.Dropout2d(p=parameter_dropout)\n", "\n", " self.patch_embedding = PatchEmbed(\n", " patch_size=patch_size_px,\n", " channels=in_channels * input_size_time,\n", " embed_dim=embed_dim,\n", " )\n", "\n", " if self.residual == \"climate\":\n", " self.patch_embedding_static = PatchEmbed(\n", " patch_size=patch_size_px,\n", " channels=in_channels + in_channels_static,\n", " embed_dim=embed_dim,\n", " )\n", " else:\n", " self.patch_embedding_static = PatchEmbed(\n", " patch_size=patch_size_px,\n", " channels=in_channels_static,\n", " embed_dim=embed_dim,\n", " )\n", "\n", " self.input_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", " self.lead_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", "\n", " self.mask_token = 
nn.Parameter(torch.randn(1, 1, 1, self.embed_dim))\n", " self._nglobal_mu = np.prod(self.global_shape_mu)\n", " self._global_idx = torch.arange(self._nglobal_mu)\n", "\n", " self._nlocal_mu = np.prod(self.local_shape_mu)\n", " self._local_idx = torch.arange(self._nlocal_mu)\n", "\n", " self.encoder = PrithviWxCEncoderDecoder(\n", " embed_dim=embed_dim,\n", " n_blocks=n_blocks_encoder,\n", " mlp_multiplier=mlp_multiplier,\n", " n_heads=n_heads,\n", " dropout=dropout,\n", " drop_path=drop_path,\n", " transformer_cp=checkpoint_encoder,\n", " )\n", "\n", " if n_blocks_decoder != 0:\n", " if self._decoder_shift:\n", " self.decoder_shifter = d_shifter = SWINShift(\n", " self.mask_unit_size_px,\n", " self.global_shape_mu,\n", " self.local_shape_mu,\n", " self.patch_size_px,\n", " n_context_tokens=0,\n", " )\n", " else:\n", " self.decoder_shifter = d_shifter = None\n", "\n", " self.decoder = PrithviWxCEncoderDecoder(\n", " embed_dim=embed_dim,\n", " n_blocks=n_blocks_decoder,\n", " mlp_multiplier=mlp_multiplier,\n", " n_heads=n_heads,\n", " dropout=dropout,\n", " drop_path=0.0,\n", " shifter=d_shifter,\n", " transformer_cp=checkpoint_decoder,\n", " )\n", "\n", " self.unembed = nn.Linear(\n", " self.embed_dim,\n", " self.in_channels\n", " * self.patch_size_px[0]\n", " * self.patch_size_px[1],\n", " bias=True,\n", " )\n", "\n", " self.masking_mode = masking_mode.lower()\n", " match self.masking_mode:\n", " case \"local\":\n", " self.generate_mask = self._gen_mask_local\n", " case \"global\":\n", " self.generate_mask = self._gen_mask_global\n", " case \"both\":\n", " self._mask_both_local: bool = True\n", " self.generate_mask = self._gen_mask_both\n", " case _:\n", " raise ValueError(\n", " f\"Masking mode '{masking_mode}' not supported\"\n", " )\n", "\n", " def swap_masking(self) -> None:\n", " self._mask_both_local = not self._mask_both_local\n", "\n", " @cached_property\n", " def n_masked_global(self):\n", " return int(self.mask_ratio_inputs * 
np.prod(self.global_shape_mu))\n", "\n", " @cached_property\n", " def n_masked_local(self):\n", " return int(self.mask_ratio_inputs * np.prod(self.local_shape_mu))\n", "\n", " @staticmethod\n", " def _shuffle_along_axis(a, axis):\n", " idx = torch.argsort(input=torch.rand(*a.shape), dim=axis)\n", " return torch.gather(a, dim=axis, index=idx)\n", "\n", " def _gen_mask_local(self, sizes: tuple[int]) -> tuple[Tensor]:\n", " \"\"\"\n", " Args:\n", " sizes: Shape tuple; its first two entries (batch, global\n", " sequence) give the leading dimensions of the mask.\n", " Returns:\n", " Tuple of torch tensors. [indices masked, indices unmasked].\n", " Each of these is a tensor with leading dimensions\n", " (batch, global sequence)\n", " \"\"\"\n", " # Identify which indices (values) should be masked\n", "\n", " maskable_indices = self._local_idx.view(1, -1).expand(*sizes[:2], -1)\n", "\n", " maskable_indices = self._shuffle_along_axis(maskable_indices, 2)\n", "\n", " indices_masked = maskable_indices[:, :, : self.n_masked_local]\n", " indices_unmasked = maskable_indices[:, :, self.n_masked_local :]\n", "\n", " return indices_masked, indices_unmasked\n", "\n", " def _gen_mask_global(self, sizes: tuple[int]) -> tuple[Tensor]:\n", " \"\"\"\n", " Args:\n", " sizes: Shape tuple; its first entry is the batch size.\n", " Returns:\n", " Tuple of torch tensors.
[indices masked, indices unmasked].\n", " Each of these is a tensor of shape (batch, global sequence)\n", " \"\"\"\n", " # Identify which indices (values) should be masked\n", "\n", " maskable_indices = self._global_idx.view(1, -1).expand(*sizes[:1], -1)\n", "\n", " maskable_indices = self._shuffle_along_axis(maskable_indices, 1)\n", "\n", " indices_masked = maskable_indices[:, : self.n_masked_global]\n", " indices_unmasked = maskable_indices[:, self.n_masked_global :]\n", "\n", " return indices_masked, indices_unmasked\n", "\n", " def _gen_mask_both(self, sizes: tuple[int]) -> tuple[Tensor]:\n", " if self._mask_both_local:\n", " return self._gen_mask_local(sizes)\n", " else:\n", " return self._gen_mask_global(sizes)\n", "\n", " @staticmethod\n", " def reconstruct_batch(\n", " idx_masked: Tensor,\n", " idx_unmasked: Tensor,\n", " data_masked: Tensor,\n", " data_unmasked: Tensor,\n", " ) -> tuple[Tensor, Tensor]:\n", " \"\"\"Reconstructs a tensor along the mask unit dimension. Batched\n", " version.\n", "\n", " Args:\n", " idx_masked: Tensor of shape `batch, mask unit sequence`.\n", " idx_unmasked: Tensor of shape `batch, mask unit sequence`.\n", " data_masked: Tensor of shape `batch, mask unit sequence, ...`.\n", " Should have same size along mask unit sequence dimension as\n", " idx_masked. Dimensions beyond the first two, marked here as ...\n", " will typically be `local_sequence, channel` or\n", " `channel, lat, lon`. These dimensions should agree with\n", " data_unmasked.\n", " data_unmasked: Tensor of shape `batch, mask unit sequence, ...`.\n", " Should have same size along mask unit sequence dimension as\n", " idx_unmasked. Dimensions beyond the first two, marked here as\n", " ... will typically be `local_sequence, channel` or `channel,\n", " lat, lon`. These dimensions should agree with data_masked.\n", " Returns:\n", " Tensor: Tensor of same shape as inputs data_masked and\n", " data_unmasked. I.e. `batch, mask unit sequence, ...`.
Index for\n", " the total data composed of the masked and the unmasked part.\n", " \"\"\"\n", " dim: int = idx_masked.ndim\n", "\n", " idx_total = torch.argsort(\n", " torch.cat([idx_masked, idx_unmasked], dim=-1), dim=-1\n", " )\n", " idx_total = idx_total.view(\n", " *idx_total.shape, *[1] * (data_unmasked.ndim - dim)\n", " )\n", " idx_total = idx_total.expand(\n", " *idx_total.shape[:dim], *data_unmasked.shape[dim:]\n", " )\n", "\n", " data = torch.cat([data_masked, data_unmasked], dim=dim - 1)\n", " data = torch.gather(data, dim=dim - 1, index=idx_total)\n", "\n", " return data, idx_total\n", "\n", " def fourier_pos_encoding(self, x_static: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " x_static: B x C x H x W. The first two channels are lat and lon.\n", " Returns:\n", " Tensor: Tensor of shape B x E x H/P x W/P where E is the\n", " embedding dimension and P the patch size.\n", " \"\"\"\n", "\n", " # B x C x H x W -> B x 1 x H/P x W/P\n", " latitudes_patch = F.avg_pool2d(\n", " x_static[:, [0]],\n", " kernel_size=self.patch_size_px,\n", " stride=self.patch_size_px,\n", " )\n", " longitudes_patch = F.avg_pool2d(\n", " x_static[:, [1]],\n", " kernel_size=self.patch_size_px,\n", " stride=self.patch_size_px,\n", " )\n", "\n", " modes = (\n", " torch.arange(self.embed_dim // 4, device=x_static.device).view(\n", " 1, -1, 1, 1\n", " )\n", " + 1.0\n", " )\n", " pos_encoding = torch.cat(\n", " (\n", " torch.sin(latitudes_patch * modes),\n", " torch.sin(longitudes_patch * modes),\n", " torch.cos(latitudes_patch * modes),\n", " torch.cos(longitudes_patch * modes),\n", " ),\n", " axis=1,\n", " )\n", "\n", " return pos_encoding # B x E x H/P x W/P\n", "\n", " def time_encoding(self, input_time, lead_time):\n", " \"\"\"\n", " Args:\n", " input_time: Tensor of shape [batch].\n", " lead_time: Tensor of shape [batch].\n", " Returns:\n", " Tensor: Tensor of shape [batch, 1, 1, embed_dim]\n", " \"\"\"\n", " input_time = self.input_time_embedding(input_time.view(-1, 1, 1, 1))\n", " lead_time =
self.lead_time_embedding(lead_time.view(-1, 1, 1, 1))\n", "\n", " time_encoding = torch.cat(\n", " (\n", " torch.cos(input_time),\n", " torch.cos(lead_time),\n", " torch.sin(input_time),\n", " torch.sin(lead_time),\n", " ),\n", " axis=3,\n", " )\n", " return time_encoding\n", "\n", " def to_patching(self, x: Tensor) -> Tensor:\n", " \"\"\"Transform data from lat/lon space to two axis patching\n", "\n", " Args:\n", " x: Tensor in lat/lon space (N, C, Nlat//P_0, Nlon//P_1)\n", "\n", " Returns:\n", " Tensor in patch space (N, G, L, C)\n", " \"\"\"\n", " n_batch = x.shape[0]\n", "\n", " x = x.view(\n", " n_batch,\n", " -1,\n", " self.global_shape_mu[0],\n", " self.local_shape_mu[0],\n", " self.global_shape_mu[1],\n", " self.local_shape_mu[1],\n", " )\n", " x = x.permute(0, 2, 4, 3, 5, 1).contiguous()\n", "\n", " s = x.shape\n", " return x.view(n_batch, s[1] * s[2], s[3] * s[4], -1)\n", "\n", " def from_patching(self, x: Tensor) -> Tensor:\n", " \"\"\"Transform data from two axis patching to lat/lon space\n", "\n", " Args:\n", " x: Tensor in patch space with shape (N, G, L, C*P_0*P_1)\n", "\n", " Returns:\n", " Tensor: Tensor in lat/lon space\n", " (N, C*P_0*P_1, Nlat//P_0, Nlon//P_1)\n", " \"\"\"\n", " n_batch = x.shape[0]\n", "\n", " x = x.view(\n", " n_batch,\n", " self.global_shape_mu[0],\n", " self.global_shape_mu[1],\n", " self.local_shape_mu[0],\n", " self.local_shape_mu[1],\n", " -1,\n", " )\n", " x = x.permute(0, 5, 1, 3, 2, 4).contiguous()\n", "\n", " s = x.shape\n", " return x.view(n_batch, -1, s[2] * s[3], s[4] * s[5])\n", "\n", " def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:\n", " \"\"\"\n", " Args:\n", " batch: Dictionary with the following keys::\n", "\n", " 'x': Tensor of shape [batch, time, parameter, lat, lon]\n", " 'y': Tensor of shape [batch, parameter, lat, lon]\n", " 'static': Tensor of shape [batch, channel_static, lat, lon]\n", " 'climate': Optional tensor of shape [batch, parameter, lat, lon]\n", " 'input_time': Tensor of 
shape [batch]. Or none.\n", " 'lead_time': Tensor of shape [batch]. Or none.\n", "\n", " Returns:\n", " Tensor: Tensor of shape [batch, parameter, lat, lon].\n", " \"\"\" # noqa: E501\n", " x_rescaled = (batch[\"x\"] - self.input_scalers_mu) / (\n", " self.input_scalers_sigma + self.input_scalers_epsilon\n", " )\n", " batch_size = x_rescaled.shape[0]\n", "\n", " if self.positional_encoding == \"fourier\":\n", " x_static_pos = self.fourier_pos_encoding(batch[\"static\"])\n", " x_static = (\n", " batch[\"static\"][:, 2:] - self.static_input_scalers_mu[:, 3:]\n", " ) / (\n", " self.static_input_scalers_sigma[:, 3:]\n", " + self.static_input_scalers_epsilon\n", " )\n", " else:\n", " x_static = (batch[\"static\"] - self.static_input_scalers_mu) / (\n", " self.static_input_scalers_sigma\n", " + self.static_input_scalers_epsilon\n", " )\n", "\n", " if self.residual == \"temporal\":\n", " # We create a residual of same shape as y\n", " index = torch.where(\n", " batch[\"lead_time\"] > 0, batch[\"x\"].shape[1] - 1, 0\n", " )\n", " index = index.view(-1, 1, 1, 1, 1)\n", " index = index.expand(batch_size, 1, *batch[\"x\"].shape[2:])\n", " x_hat = torch.gather(batch[\"x\"], dim=1, index=index)\n", " x_hat = x_hat.squeeze(1)\n", " elif self.residual == \"climate\":\n", " climate_scaled = (\n", " batch[\"climate\"] - self.input_scalers_mu.view(1, -1, 1, 1)\n", " ) / (\n", " self.input_scalers_sigma.view(1, -1, 1, 1)\n", " + self.input_scalers_epsilon\n", " )\n", "\n", " # [batch, time, parameter, lat, lon]\n", " # -> [batch, time x parameter, lat, lon]\n", " x_rescaled = x_rescaled.flatten(1, 2)\n", " # Parameter dropout\n", " x_rescaled = self.parameter_dropout(x_rescaled)\n", "\n", " x_embedded = self.patch_embedding(x_rescaled)\n", "\n", " if self.residual == \"climate\":\n", " static_embedded = self.patch_embedding_static(\n", " torch.cat((x_static, climate_scaled), dim=1)\n", " )\n", " else:\n", " static_embedded = self.patch_embedding_static(x_static)\n", "\n", " if 
self.positional_encoding == \"fourier\":\n", " static_embedded += x_static_pos\n", "\n", " x_embedded = self.to_patching(x_embedded)\n", " static_embedded = self.to_patching(static_embedded)\n", "\n", " time_encoding = self.time_encoding(\n", " batch[\"input_time\"], batch[\"lead_time\"]\n", " )\n", "\n", " tokens = x_embedded + static_embedded + time_encoding\n", "\n", " # Now we generate masks based on masking_mode\n", " indices_masked, indices_unmasked = self.generate_mask(\n", " (batch_size, self._nglobal_mu)\n", " )\n", " indices_masked = indices_masked.to(device=tokens.device)\n", " indices_unmasked = indices_unmasked.to(device=tokens.device)\n", " maskdim: int = indices_masked.ndim\n", "\n", " # Unmasking\n", " unmask_view = (*indices_unmasked.shape, *[1] * (tokens.ndim - maskdim))\n", " unmasked = torch.gather(\n", " tokens,\n", " dim=maskdim - 1,\n", " index=indices_unmasked.view(*unmask_view).expand(\n", " *indices_unmasked.shape, *tokens.shape[maskdim:]\n", " ),\n", " )\n", "\n", " # Encoder\n", " x_encoded = self.encoder(unmasked)\n", "\n", " # Generate and position encode the mask tokens\n", " # [1, 1, 1, embed_dim]\n", " # -> [batch, global_seq_masked, local seq, embed_dim]\n", " mask_view = (*indices_masked.shape, *[1] * (tokens.ndim - maskdim))\n", " masking = self.mask_token.repeat(*static_embedded.shape[:3], 1)\n", " masked = masking + static_embedded\n", " masked = torch.gather(\n", " masked,\n", " dim=maskdim - 1,\n", " index=indices_masked.view(*mask_view).expand(\n", " *indices_masked.shape, *tokens.shape[maskdim:]\n", " ),\n", " )\n", "\n", " recon, _ = self.reconstruct_batch(\n", " indices_masked, indices_unmasked, masked, x_encoded\n", " )\n", "\n", " x_decoded = self.decoder(recon)\n", "\n", " # Output: [batch, global sequence, local sequence,\n", " # in_channels * patch_size[0] * patch_size[1]]\n", " x_unembed = self.unembed(x_decoded)\n", "\n", " # Reshape to [batch, global_lat, global_lon, local_lat, local_lon,\n", " # in_channels * 
patch_size[0] * patch_size[1]]\n", " x_out = self.from_patching(x_unembed)\n", "\n", " # Pixel shuffle to [batch, in_channels, lat, lon]\n", " x_out = F.pixel_shuffle(x_out, self.patch_size_px[0])\n", "\n", " if self.residual == \"temporal\":\n", " x_out = self.output_scalers * x_out + x_hat\n", " elif self.residual == \"climate\":\n", " x_out = self.output_scalers * x_out + batch[\"climate\"]\n", " elif self.residual == \"none\":\n", " x_out = (\n", " self.output_scalers * x_out\n", " + self.input_scalers_mu.reshape(1, -1, 1, 1)\n", " )\n", "\n", " return x_out\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "ename": "", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", "\u001b[1;31mClick here for more info. \n", "\u001b[1;31mView Jupyter log for further details." 
] } ], "source": [ "import yaml\n", "\n", "# from PrithviWxC.model import PrithviWxC\n", "\n", "with open(\"./config.yaml\", \"r\") as f:\n", " config = yaml.safe_load(f)\n", "\n", "model = PrithviWxC(\n", " in_channels=config[\"params\"][\"in_channels\"],\n", " input_size_time=config[\"params\"][\"input_size_time\"],\n", " in_channels_static=config[\"params\"][\"in_channels_static\"],\n", " input_scalers_mu=in_mu,\n", " input_scalers_sigma=in_sig,\n", " input_scalers_epsilon=config[\"params\"][\"input_scalers_epsilon\"],\n", " static_input_scalers_mu=static_mu,\n", " static_input_scalers_sigma=static_sig,\n", " static_input_scalers_epsilon=config[\"params\"][\n", " \"static_input_scalers_epsilon\"\n", " ],\n", " output_scalers=output_sig**0.5,\n", " n_lats_px=config[\"params\"][\"n_lats_px\"],\n", " n_lons_px=config[\"params\"][\"n_lons_px\"],\n", " patch_size_px=config[\"params\"][\"patch_size_px\"],\n", " mask_unit_size_px=config[\"params\"][\"mask_unit_size_px\"],\n", " mask_ratio_inputs=masking_ratio,\n", " embed_dim=config[\"params\"][\"embed_dim\"],\n", " n_blocks_encoder=config[\"params\"][\"n_blocks_encoder\"],\n", " n_blocks_decoder=config[\"params\"][\"n_blocks_decoder\"],\n", " mlp_multiplier=config[\"params\"][\"mlp_multiplier\"],\n", " n_heads=config[\"params\"][\"n_heads\"],\n", " dropout=config[\"params\"][\"dropout\"],\n", " drop_path=config[\"params\"][\"drop_path\"],\n", " parameter_dropout=config[\"params\"][\"parameter_dropout\"],\n", " residual=residual,\n", " masking_mode=masking_mode,\n", " decoder_shifting=decoder_shifting,\n", " positional_encoding=positional_encoding,\n", " checkpoint_encoder=[],\n", " checkpoint_decoder=[],\n", ")\n", "\n", "\n", "state_dict = torch.load(weights_path, weights_only=False)\n", "if \"model_state\" in state_dict:\n", " state_dict = state_dict[\"model_state\"]\n", "model.load_state_dict(state_dict, strict=True)\n", "\n", "if (hasattr(model, \"device\") and model.device != device) or not hasattr(\n", " model, 
\"device\"\n", "):\n", " model = model.to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Rollout\n", "We are now ready to perform the rollout. Again, the data has to be run through\n", "a preprocessor. However, this time we use a preprocessor that can handle the\n", "additional intermediate data. Also, rather than calling the model directly, we\n", "have a convenient wrapper function that performs the iteration. This also\n", "simplifies the model loading when using a sharded checkpoint. If you attempt to\n", "perform training steps with this function, you should use an aggressive number\n", "of activation checkpoints, as the memory consumption becomes quite high." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import Tensor, nn\n", "\n", "\n", "def rollout_iter(\n", " nsteps: int,\n", " model: nn.Module,\n", " batch: dict[str, Tensor | int | float],\n", ") -> Tensor:\n", " \"\"\"A helper function for performing autoregressive rollout.\n", "\n", " Args:\n", " nsteps (int): The number of rollout steps to take\n", " model (nn.Module): A model.\n", " batch (dict): A data dictionary common to the Prithvi models.\n", "\n", " Raises:\n", " ValueError: If the number of steps isn't positive.\n", "\n", " Returns:\n", " Tensor: the output of the model after nsteps autoregressive iterations.\n", " \"\"\"\n", " if nsteps < 1:\n", " raise ValueError(\"'nsteps' should be a positive int.\")\n", "\n", " xlast = batch[\"x\"][:, 1]\n", " batch[\"lead_time\"] = batch[\"lead_time\"][..., 0]\n", "\n", " # Save the masking ratio to be restored later\n", " mask_ratio_tmp = model.mask_ratio_inputs\n", "\n", " for step in range(nsteps):\n", " # After first step, turn off masking\n", " if step > 0:\n", " model.mask_ratio_inputs = 0.0\n", "\n", " batch[\"static\"] = batch[\"statics\"][:, step]\n", " batch[\"climate\"] = batch[\"climates\"][:, step]\n", " batch[\"y\"] = batch[\"ys\"][:, 
step]\n", "\n", " out = model(batch)\n", "\n", " batch[\"x\"] = torch.cat((xlast[:, None], out[:, None]), dim=1)\n", " xlast = out\n", "\n", " # Restore the masking ratio\n", " model.mask_ratio_inputs = mask_ratio_tmp\n", "\n", " return xlast\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# from PrithviWxC.dataloaders.merra2_rollout import preproc\n", "# from PrithviWxC.rollout import rollout_iter\n", "\n", "data = next(iter(dataset))\n", "batch = preproc([data], padding)\n", "\n", "for k, v in batch.items():\n", " if isinstance(v, torch.Tensor):\n", " batch[k] = v.to(device)\n", "\n", "rng_state_1 = torch.get_rng_state()\n", "with torch.no_grad():\n", " model.eval()\n", " out = rollout_iter(dataset.nsteps, model, batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "t2m = out[0, 12].cpu().numpy()\n", "\n", "lat = np.linspace(-90, 90, out.shape[-2])\n", "lon = np.linspace(-180, 180, out.shape[-1])\n", "X, Y = np.meshgrid(lon, lat)\n", "\n", "plt.contourf(X, Y, t2m, 100)\n", "plt.gca().set_aspect(\"equal\")\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 2 }