diff --git a/weather_mv/README.md b/weather_mv/README.md
index c36530d9..48149528 100644
--- a/weather_mv/README.md
+++ b/weather_mv/README.md
@@ -408,6 +408,10 @@ _Command options_:
   takes extra time in COG creation. Default:False.
 * `--use_metrics`: A flag that allows you to add Beam metrics to the pipeline. Default: False.
 * `--use_monitoring_metrics`: A flag that allows you to to add Google Cloud Monitoring metrics to the pipeline. Default: False.
+* `--partition_dims`: If the dataset contains dimensions other than latitude and longitude, partition the dataset into multiple datasets based on these dimensions. A separate COG file will be created for each partition and ingested into Earth Engine. Any unspecified dimensions will be flattened in the resulting COG.
+* `--asset_name_format`: The asset name format for each partitioned COG file. It should contain no dimensions other than partition_dims (plus init_time and valid_time, which require --forecast_dim_mapping). The dimension names should be enclosed in {} (e.g. a valid format is {init_time}_{valid_time}_{number}).
+* `--forecast_dim_mapping`: A JSON string containing init_time and valid_time as keys and the corresponding dimension name for each key. It is required if init_time or valid_time is used in asset_name_format.
+* `--date_format`: A string containing datetime.strftime codes. It is used if a dimension mentioned in asset_name_format is a datetime. Default: %Y%m%d%H%M
 
 Invoke with `ee -h` or `earthengine --help` to see the full range of options.
 
@@ -483,6 +487,28 @@ weather-mv ee --uris "gs://your-bucket/*.grib" \
   --temp_location "gs://$BUCKET/tmp"
 ```
 
+Create a separate COG file for every value of the time and number dimensions:
+
+```bash
+weather-mv ee --uris "gs://your-bucket/*.grib" \
+  --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib
+  --ee_asset "projects/$PROJECT/assets/test_dir" \
+  --partition_dims time number \
+  --asset_name_format "{time}_{number}"
+```
+
+Create COG files named after init_time and valid_time (datetime values), formatted as 'YYYYMMDD':
+
+```bash
+weather-mv ee --uris "gs://your-bucket/*.grib" \
+  --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib
+  --ee_asset "projects/$PROJECT/assets/test_dir" \
+  --partition_dims time step number \ # step is in timedelta
+  --forecast_dim_mapping '{"init_time": "time", "valid_time": "step"}' \
+  --asset_name_format "{init_time}_{valid_time}_{number}" \
+  --date_format "%Y%m%d"
+```
+
 Limit EE requests:
 
 ```bash
diff --git a/weather_mv/loader_pipeline/ee.py b/weather_mv/loader_pipeline/ee.py
index ef9389ba..a2b7c234 100644
--- a/weather_mv/loader_pipeline/ee.py
+++ b/weather_mv/loader_pipeline/ee.py
@@ -14,6 +14,7 @@
 import argparse
 import csv
 import dataclasses
+import itertools
 import json
 import logging
 import math
@@ -29,6 +30,7 @@
 import apache_beam as beam
 import ee
 import numpy as np
+import xarray as xr
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.metrics import metric
 from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE
@@ -39,8 +41,15 @@
 from google.auth.transport.requests import AuthorizedSession
 from rasterio.io import MemoryFile
 
-from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin, upload
-from .util import make_attrs_ee_compatible, RateLimit, validate_region, get_utc_timestamp
+from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin, upload, _to_utc_timestring
+from .util import (
+    make_attrs_ee_compatible,
+    RateLimit,
+    validate_region,
+    get_utc_timestamp,
+    get_dims_from_name_format,
+    convert_to_string
+)
 from .metrics import timeit, AddTimer, AddMetrics
 
 logger = logging.getLogger(__name__)
@@ -132,6 +141,139 @@ def ee_initialize(use_personal_account: bool = False,
     ee.Initialize(creds)
 
 
+def construct_asset_name(attrs: t.Dict, asset_name_format: str) -> str:
+    """Generates the asset name from asset_name_format, using dataset attributes."""
+    dims = get_dims_from_name_format(asset_name_format)
+    dim_values = {}
+
+    # Get init_time and valid_time from their plain keys; get the other
+    # dimension values from the '_value' keys of the attributes.
+    for dim in dims:
+        dim_values[dim] = attrs[dim + '_value'] if dim + '_value' in attrs else attrs[dim]
+
+    asset_name = asset_name_format.format(**dim_values)
+    return asset_name
+
+
+def add_additional_attrs(ds: xr.Dataset, forecast_dim_mapping: t.Dict[str, str], date_format: str) -> t.Dict:
+    """
+    Returns additional attributes (init_time, valid_time, start_time, end_time, forecast_seconds)
+    derived from the dimensions named in forecast_dim_mapping.
+    """
+    attrs = {}
+    if (forecast_dim_mapping['init_time'] not in ds) or (forecast_dim_mapping['valid_time'] not in ds):
+        raise ValueError('The dimension passed for init_time/valid_time is not present in the dataset.')
+
+    init_time_da, valid_time_da = ds[forecast_dim_mapping['init_time']], ds[forecast_dim_mapping['valid_time']]
+    start_time = init_time_da.values
+
+    if isinstance(valid_time_da.values, np.timedelta64):
+        end_time = start_time + valid_time_da.values
+    elif isinstance(valid_time_da.values, np.datetime64):
+        end_time = valid_time_da.values
+    else:
+        end_time = start_time + np.timedelta64(valid_time_da.values, 'h')
+
+    attrs['forecast_seconds'] = int((end_time - start_time) / np.timedelta64(1, 's'))
+    attrs['start_time'] = _to_utc_timestring(start_time)
+    attrs['end_time'] = _to_utc_timestring(end_time)
+
+    attrs['init_time'] = convert_to_string(start_time, date_format)
+    attrs['valid_time'] = convert_to_string(end_time, date_format)
+
+    return attrs
+
+
+def partition_dataset(ds: xr.Dataset,
+                      partition_dims: t.List[str],
+                      forecast_dim_mapping: t.Dict[str, str],
+                      asset_name_format: str,
+                      date_format: str) -> t.List[xr.Dataset]:
+    """
+    Partitions a dataset based on the specified dimensions and flattens other dimensions into variable names.
+
+    Args:
+        ds (xr.Dataset): Input xarray dataset.
+        partition_dims (list): List of dimensions to partition by (e.g., ['time', 'step']).
+        forecast_dim_mapping (dict): Dictionary containing init_time and valid_time as keys. This is used to add
+            the start_time, end_time and forecast_seconds attributes to each dataset. It is also used to
+            calculate the valid_time if the step value is in timedelta format.
+        asset_name_format (str): Specifies the format for the asset name of the resulting COG,
+            containing dimensions enclosed in {} (e.g., '{init_time}_{valid_time}').
+        date_format (str): Datetime format to use in the asset name if the dimension is of type datetime.
+
+    Returns:
+        list: A list of partitioned xarray datasets.
+ """ + # Ensure partition_dims are valid dimensions + all_dims = list(ds.dims.keys()) + for dim in partition_dims: + if dim not in all_dims: + raise ValueError(f"Dimension '{dim}' is not present in the dataset.") + + # Dimensions to flatten (all except latitude, longitude, and partition_dims) + to_flatten = [dim for dim in all_dims if dim not in partition_dims + ['latitude', 'longitude']] + + partition_indices = itertools.product(*[range(ds.sizes[dim]) for dim in partition_dims]) + partitioned_datasets = [] + + for idx in partition_indices: + # Partition the dataset based on the partition_dims + selector = {dim: idx[i] for i, dim in enumerate(partition_dims)} + sliced_ds = ds.isel(selector) + + # Add attributes (init_time, valid_time, forecast_seconds) in dataset + sliced_ds.attrs.update(**add_additional_attrs(sliced_ds, forecast_dim_mapping, date_format)) + + # Flatten the remaining dimensions into variable names + new_data_vars = {} + for var_name, data_array in sliced_ds.data_vars.items(): + # Iterate over the indexes created for the dimension that needs to be flatten + for flat_idx in itertools.product(*[range(sliced_ds.sizes[dim]) for dim in to_flatten]): + flat_selector = {dim: flat_idx[i] for i, dim in enumerate(to_flatten)} + flat_data_array = data_array.isel(flat_selector).squeeze() + + # Construct the variable name with the values of flattened dimensions + parts = [] + for dim in to_flatten: + value = sliced_ds[dim].values[flat_selector[dim]] + parts.append(f"{dim}_{convert_to_string(value, date_format, make_ee_safe=True)}") + parts.append(var_name) + + flat_var_name = f"{'_'.join(parts)}" + new_data_vars[flat_var_name] = xr.DataArray( + flat_data_array.values, + dims=['latitude', 'longitude'], + coords={'latitude': flat_data_array.latitude, 'longitude': flat_data_array.longitude}, + attrs=flat_data_array.attrs + ) + + # Create the new dataset with only latitude and longitude dimensions + new_ds = xr.Dataset( + new_data_vars, + coords={ + 'latitude': sliced_ds.latitude, + 'longitude': sliced_ds.longitude + }, + attrs=sliced_ds.attrs + ) + + # Add the values of partitioned dimensions as attributes + new_ds.attrs.update( + **{ + f"{dim}_value": convert_to_string(sliced_ds[dim].values, date_format) + for dim in partition_dims + } + ) + + # Create asset_name and store it in attributes + asset_name = construct_asset_name(new_ds.attrs, asset_name_format) + new_ds.attrs['asset_name'] = asset_name + partitioned_datasets.append(new_ds) + + return partitioned_datasets + + class SetupEarthEngine(RateLimit): """A base class to setup the earth engine.""" @@ -254,6 +396,10 @@ class ToEarthEngine(ToDataSink): use_metrics: bool use_monitoring_metrics: bool topic: str + partition_dims: t.List[str] + asset_name_format: str + forecast_dim_mapping: t.Dict[str, str] + date_format: str # Pipeline arguments. job_name: str project: str @@ -306,6 +452,25 @@ def add_parser_arguments(cls, subparser: argparse.ArgumentParser): help='If you want to add Beam metrics to your pipeline. Default: False') subparser.add_argument('--use_monitoring_metrics', action='store_true', default=False, help='If you want to add GCP Monitoring metrics to your pipeline. Default: False') + subparser.add_argument('--partition_dims', nargs='*', default=None, + help='If the dataset contains other dimensions apart from latitude and longitude,' + ' partition the dataset into multiple datasets based on these dimensions.' + ' A separate COG file will be created for each partition.' 
+                                    ' Any unspecified dimensions will be flattened in the resulting COG.')
+        subparser.add_argument('--asset_name_format', type=str, default=None,
+                               help='The asset name format for each partitioned COG file.'
+                                    ' This should contain no dimensions other than partition_dims'
+                                    ' (although you can add init_time and valid_time provided that'
+                                    ' you have given forecast_dim_mapping). The dimension names should be'
+                                    ' enclosed in {} (e.g. a valid format is {init_time}_{valid_time}_{number}).')
+        subparser.add_argument('--forecast_dim_mapping', type=json.loads, default=None,
+                               help='A JSON string containing init_time and valid_time as keys and the'
+                                    ' corresponding dimension name for each key.'
+                                    ' It is required if init_time or valid_time is used in asset_name_format.')
+        subparser.add_argument('--date_format', type=str, default='%Y%m%d%H%M',
+                               help='A string containing datetime.strftime codes.'
+                                    ' It is used if a dimension mentioned in asset_name_format is a datetime.'
+                                    ' Default: %%Y%%m%%d%%H%%M')
 
     @classmethod
     def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None:
@@ -360,6 +525,39 @@ def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.Lis
         if bool(known_args.initialization_time_regex) ^ bool(known_args.forecast_time_regex):
             raise RuntimeError("Both --initialization_time_regex & --forecast_time_regex flags need to be present")
 
+        if bool(known_args.partition_dims) ^ bool(known_args.asset_name_format):
+            raise RuntimeError("Both --partition_dims & --asset_name_format flags need to be present")
+
+        if not known_args.partition_dims and bool(known_args.forecast_dim_mapping):
+            raise RuntimeError("forecast_dim_mapping can only be specified when partition_dims are passed.")
+
+        # Check whether forecast_dim_mapping contains both init_time and valid_time.
+        if (
+            known_args.forecast_dim_mapping
+            and not (
+                'init_time' in known_args.forecast_dim_mapping
+                and 'valid_time' in known_args.forecast_dim_mapping
+            )
+        ):
+            raise RuntimeError('forecast_dim_mapping should contain both init_time and valid_time as keys.')
+
+        # Perform these checks only when partition_dims are specified.
+        if known_args.partition_dims and known_args.ee_asset_type != "IMAGE":
+            raise RuntimeError('partition_dims should be specified for the "IMAGE" asset_type only.')
+
+        # Check whether the asset name format contains valid dimensions.
+        if known_args.asset_name_format:
+            dims = get_dims_from_name_format(known_args.asset_name_format)
+            for dim in dims:
+                if dim not in known_args.partition_dims + ['init_time', 'valid_time']:
+                    raise RuntimeError('Only the dimensions used for partitioning can be used in the asset name. '
+                                       f'{dim} is not used to partition the dataset.')
+
+            if ('init_time' in dims or 'valid_time' in dims) ^ bool(known_args.forecast_dim_mapping):
+                raise RuntimeError('If asset_name_format contains init_time or valid_time, then forecast_dim_mapping '
+                                   'is required. Conversely, if forecast_dim_mapping is provided, asset_name_format '
+                                   'must include either init_time or valid_time or both.')
+
         logger.info(f"Add metrics to pipeline: {known_args.use_metrics}")
         logger.info(f"Add Google Cloud Monitoring metrics to pipeline: {known_args.use_monitoring_metrics}")
 
@@ -475,6 +673,10 @@ class ConvertToAsset(beam.DoFn, KwargsFactoryMixin):
     forecast_time_regex: t.Optional[str] = None
     use_deflate: t.Optional[bool] = False
     use_metrics: t.Optional[bool] = False
+    partition_dims: t.Optional[list] = None
+    asset_name_format: t.Optional[str] = None
+    forecast_dim_mapping: t.Optional[t.Dict] = None
+    date_format: str = '%Y%m%d%H%M'
 
     def add_to_queue(self, queue: Queue, item: t.Any):
         """Adds a new item to the queue.
@@ -501,119 +703,131 @@ def convert_to_asset(self, queue: Queue, uri: str):
         if not isinstance(ds_list, list):
             ds_list = [ds_list]
 
-        for ds in ds_list:
-            attrs = ds.attrs
-            data = list(ds.values())
-            asset_name = get_ee_safe_name(uri)
-            channel_names = [
-                self.band_names_dict.get(da.name, da.name) if self.band_names_dict
-                else da.name for da in data
-            ]
-
-            dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform'])
-            # Adding job_start_time to properites.
-            attrs["job_start_time"] = job_start_time
-            # Make attrs EE ingestable.
-            attrs = make_attrs_ee_compatible(attrs)
-            start_time, end_time = (attrs.get(key) for key in ('start_time', 'end_time'))
-
-            if self.group_common_hypercubes:
-                level, height = (attrs.pop(key) for key in ['level', 'height'])
-                safe_level_name = get_ee_safe_name(level)
-                asset_name = f'{asset_name}_{safe_level_name}'
-
-            compression = 'lzw'
-            predictor = 'NO'
-            if self.use_deflate:
-                compression = 'deflate'
-                # Depending on dtype select predictor value.
-                # Predictor is a method of storing only the difference from the
-                # previous value instead of the actual value.
-                predictor = 2 if np.issubdtype(dtype, np.integer) else 3
-
-            # For tiff ingestions.
-            if self.ee_asset_type == 'IMAGE':
-                file_name = f'{asset_name}.tiff'
-
-                with MemoryFile() as memfile:
-                    with memfile.open(driver='COG',
-                                      dtype=dtype,
-                                      width=data[0].data.shape[1],
-                                      height=data[0].data.shape[0],
-                                      count=len(data),
-                                      nodata=np.nan,
-                                      crs=crs,
-                                      transform=transform,
-                                      compress=compression,
-                                      predictor=predictor) as f:
-                        for i, da in enumerate(data):
-                            f.write(da, i+1)
-                            # Making the channel name EE-safe before adding it as a band name.
-                            f.set_band_description(i+1, get_ee_safe_name(channel_names[i]))
-                            f.update_tags(i+1, band_name=channel_names[i])
-                            f.update_tags(i+1, **da.attrs)
-
-                        # Write attributes as tags in tiff.
-                        f.update_tags(**attrs)
-
-                    # Copy in-memory tiff to gcs.
+        for dataset in ds_list:
+            if self.partition_dims and self.ee_asset_type == 'IMAGE':
+                partitioned_datasets = partition_dataset(
+                    dataset,
+                    self.partition_dims,
+                    self.forecast_dim_mapping,
+                    self.asset_name_format,
+                    self.date_format
+                )
+            else:
+                partitioned_datasets = [dataset]
+            for ds in partitioned_datasets:
+                attrs = ds.attrs
+                data = list(ds.values())
+                asset_name = attrs.pop('asset_name') if 'asset_name' in attrs else get_ee_safe_name(uri)
+                channel_names = [
+                    self.band_names_dict.get(da.name, da.name) if self.band_names_dict
+                    else da.name for da in data
+                ]
+
+                dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform'])
+                # Adding job_start_time to properties.
+                attrs["job_start_time"] = job_start_time
+                # Make attrs EE ingestable.
+                attrs = make_attrs_ee_compatible(attrs)
+                start_time, end_time = (attrs.get(key) for key in ('start_time', 'end_time'))
+
+                if self.group_common_hypercubes:
+                    level, height = (attrs.pop(key) for key in ['level', 'height'])
+                    safe_level_name = get_ee_safe_name(level)
+                    asset_name = f'{asset_name}_{safe_level_name}'
+
+                compression = 'lzw'
+                predictor = 'NO'
+                if self.use_deflate:
+                    compression = 'deflate'
+                    # Depending on dtype select predictor value.
+                    # Predictor is a method of storing only the difference from the
+                    # previous value instead of the actual value.
+                    predictor = 2 if np.issubdtype(dtype, np.integer) else 3
+
+                # For tiff ingestions.
+                if self.ee_asset_type == 'IMAGE':
+                    file_name = f'{asset_name}.tiff'
+
+                    with MemoryFile() as memfile:
+                        with memfile.open(driver='COG',
+                                          dtype=dtype,
+                                          width=data[0].data.shape[1],
+                                          height=data[0].data.shape[0],
+                                          count=len(data),
+                                          nodata=np.nan,
+                                          crs=crs,
+                                          transform=transform,
+                                          compress=compression,
+                                          predictor=predictor) as f:
+                            for i, da in enumerate(data):
+                                f.write(da, i+1)
+                                # Making the channel name EE-safe before adding it as a band name.
+                                f.set_band_description(i+1, get_ee_safe_name(channel_names[i]))
+                                f.update_tags(i+1, band_name=channel_names[i])
+                                f.update_tags(i+1, **da.attrs)
+
+                            # Write attributes as tags in tiff.
+                            f.update_tags(**attrs)
+
+                        # Copy in-memory tiff to gcs.
+                        target_path = os.path.join(self.asset_location, file_name)
+                        with FileSystems().create(target_path) as dst:
+                            shutil.copyfileobj(memfile, dst, WRITE_CHUNK_SIZE)
+                        child_logger.info(f"Uploaded {uri!r}'s COG to {target_path}")
+
+                # For feature collection ingestions.
+                elif self.ee_asset_type == 'TABLE':
+                    channel_names = []
+                    file_name = f'{asset_name}.csv'
+
+                    shape = math.prod(list(ds.dims.values()))
+                    # Names of dimensions, coordinates and data variables.
+                    dims = list(ds.dims)
+                    coords = [c for c in list(ds.coords) if c not in dims]
+                    vars = list(ds.data_vars)
+                    header = dims + coords + vars
+
+                    # Data of dimensions, coordinates and data variables.
+                    dims_data = [ds[dim].data for dim in dims]
+                    coords_data = [np.full((shape,), ds[coord].data) for coord in coords]
+                    vars_data = [ds[var].data.flatten() for var in vars]
+                    data = coords_data + vars_data
+
+                    dims_shape = [len(ds[dim].data) for dim in dims]
+
+                    def get_dims_data(index: int) -> t.List[t.Any]:
+                        """Returns dimensions for the given flattened index."""
+                        return [
+                            dim[int(index / math.prod(dims_shape[i + 1:])) % len(dim)]
+                            for i, dim in enumerate(dims_data)
+                        ]
+
+                    # Copy CSV to gcs.
                     target_path = os.path.join(self.asset_location, file_name)
-                    with FileSystems().create(target_path) as dst:
-                        shutil.copyfileobj(memfile, dst, WRITE_CHUNK_SIZE)
-                    child_logger.info(f"Uploaded {uri!r}'s COG to {target_path}")
-
-            # For feature collection ingestions.
-            elif self.ee_asset_type == 'TABLE':
-                channel_names = []
-                file_name = f'{asset_name}.csv'
-
-                shape = math.prod(list(ds.dims.values()))
-                # Names of dimesions, coordinates and data variables.
-                dims = list(ds.dims)
-                coords = [c for c in list(ds.coords) if c not in dims]
-                vars = list(ds.data_vars)
-                header = dims + coords + vars
-
-                # Data of dimesions, coordinates and data variables.
-                dims_data = [ds[dim].data for dim in dims]
-                coords_data = [np.full((shape,), ds[coord].data) for coord in coords]
-                vars_data = [ds[var].data.flatten() for var in vars]
-                data = coords_data + vars_data
-
-                dims_shape = [len(ds[dim].data) for dim in dims]
-
-                def get_dims_data(index: int) -> t.List[t.Any]:
-                    """Returns dimensions for the given flattened index."""
-                    return [
-                        dim[int(index/math.prod(dims_shape[i+1:])) % len(dim)] for (i, dim) in enumerate(dims_data)
-                    ]
-
-                # Copy CSV to gcs.
-                target_path = os.path.join(self.asset_location, file_name)
-                with tempfile.NamedTemporaryFile() as temp:
-                    with open(temp.name, 'w', newline='') as f:
-                        writer = csv.writer(f)
-                        writer.writerows([header])
-                        # Write rows in batches.
-                        for i in range(0, shape, ROWS_PER_WRITE):
-                            writer.writerows(
-                                [get_dims_data(i) + list(row) for row in zip(
-                                    *[d[i:i + ROWS_PER_WRITE] for d in data]
-                                )]
-                            )
-
-                    upload(temp.name, target_path)
-
-            asset_data = AssetData(
-                name=asset_name,
-                target_path=target_path,
-                channel_names=channel_names,
-                start_time=start_time,
-                end_time=end_time,
-                properties=attrs
-            )
+                    with tempfile.NamedTemporaryFile() as temp:
+                        with open(temp.name, 'w', newline='') as f:
+                            writer = csv.writer(f)
+                            writer.writerows([header])
+                            # Write rows in batches.
+                            for i in range(0, shape, ROWS_PER_WRITE):
+                                writer.writerows(
+                                    [get_dims_data(i) + list(row) for row in zip(
+                                        *[d[i:i + ROWS_PER_WRITE] for d in data]
+                                    )]
+                                )
+
+                        upload(temp.name, target_path)
+
+                asset_data = AssetData(
+                    name=asset_name,
+                    target_path=target_path,
+                    channel_names=channel_names,
+                    start_time=start_time,
+                    end_time=end_time,
+                    properties=attrs
+                )
 
-            self.add_to_queue(queue, asset_data)
+                self.add_to_queue(queue, asset_data)
 
         self.add_to_queue(queue, None)  # Indicates end of the subprocess.
 
     @timeit('ConvertToAsset')
diff --git a/weather_mv/loader_pipeline/ee_test.py b/weather_mv/loader_pipeline/ee_test.py
index 7627152e..36726e73 100644
--- a/weather_mv/loader_pipeline/ee_test.py
+++ b/weather_mv/loader_pipeline/ee_test.py
@@ -11,14 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import fnmatch
 import logging
 import os
 import tempfile
 import unittest
+import xarray as xr
+import numpy as np
 
 from .ee import (
     get_ee_safe_name,
-    ConvertToAsset
+    ConvertToAsset,
+    add_additional_attrs,
+    partition_dataset,
+    construct_asset_name
 )
 from .sinks_test import TestDataBase
 
@@ -107,6 +113,104 @@ def test_convert_to_table_asset__with_multiple_grib_edition(self):
         # The size of tiff is expected to be more than grib.
         self.assertTrue(os.path.getsize(asset_path) > os.path.getsize(data_path))
 
+    def test_convert_to_multi_image_asset(self):
+        convert_to_multi_image_asset = ConvertToAsset(
+            asset_location=self.tmpdir.name,
+            partition_dims=['time', 'step'],
+            asset_name_format='{init_time}_{valid_time}',
+            forecast_dim_mapping={
+                'init_time': 'time',
+                'valid_time': 'step'
+            }
+        )
+        data_path = f'{self.test_data_folder}/test_data_multi_dimension.nc'
+        asset_path = os.path.join(self.tmpdir.name)
+        list(convert_to_multi_image_asset.process(data_path))
+
+        # Make sure there are a total of 9 tiff files generated at the target location.
+        self.assertEqual(len(fnmatch.filter(os.listdir(asset_path), '*.tiff')), 9)
+
+
+class PartitionDatasetTests(TestDataBase):
+
+    def setUp(self):
+        super().setUp()
+        self.forecast_dim_mapping = {
+            'init_time': 'time',
+            'valid_time': 'step'
+        }
+        self.date_format = '%Y%m%d%H%M'
+        self.ds = xr.open_dataset(f'{self.test_data_folder}/test_data_multi_dimension.nc')
+
+    def test_add_additional_attrs(self):
+
+        sliced_ds = self.ds.isel({'time': 0, 'step': 1})
+        attrs = add_additional_attrs(sliced_ds, self.forecast_dim_mapping, self.date_format)
+        attr_names = ['init_time', 'valid_time', 'start_time', 'end_time', 'forecast_seconds']
+
+        for name in attr_names:
+            self.assertTrue(name in attrs)
+
+        self.assertEqual(attrs['init_time'], '202412010000')
+        self.assertEqual(attrs['valid_time'], '202412010600')
+        self.assertEqual(attrs['start_time'], '2024-12-01T00:00:00Z')
+        self.assertEqual(attrs['end_time'], '2024-12-01T06:00:00Z')
+        self.assertEqual(attrs['forecast_seconds'], 6 * 60 * 60)
+
+    def test_construct_asset_name(self):
+
+        asset_name_format = '{init_time}_{valid_time}_{level}'
+        attrs = {
+            'init_time': '202412010000',
+            'valid_time': '202412010600',
+            'level_value': np.array(2),
+            'time_value': np.datetime64('2024-12-01'),
+            'step_value': np.timedelta64(6, 'h')
+        }
+
+        self.assertEqual(construct_asset_name(attrs, asset_name_format), '202412010000_202412010600_2')
+
+    def test_partition_dataset(self):
+
+        partition_dims = ['time', 'step']
+        asset_name_format = '{init_time}_{valid_time}'
+        partition_datasets = partition_dataset(
+            self.ds, partition_dims, self.forecast_dim_mapping, asset_name_format, self.date_format
+        )
+
+        # As the ds is partitioned on time (3) and step (3), there should be a total of 9 datasets.
+        self.assertEqual(len(partition_datasets), 9)
+
+        dates = ['202412010000', '202412020000', '202412030000']
+        valid_times = [
+            '202412010000', '202412010600', '202412011200',
+            '202412020000', '202412020600', '202412021200',
+            '202412030000', '202412030600', '202412031200',
+        ]
+
+        for i, dataset in enumerate(partition_datasets):
+
+            # Make sure the level dimension is flattened and its values are added to the DataArray names.
+            self.assertEqual(len(dataset.data_vars), 6)
+            self.assertTrue('level_0_x' in dataset.data_vars)
+            self.assertTrue('level_0_y' in dataset.data_vars)
+            self.assertTrue('level_0_z' in dataset.data_vars)
+            self.assertTrue('level_1_x' in dataset.data_vars)
+            self.assertTrue('level_1_y' in dataset.data_vars)
+            self.assertTrue('level_1_z' in dataset.data_vars)
+
+            # Make sure the dataset contains only 2 dimensions: latitude and longitude.
+            self.assertEqual(len(dataset.dims), 2)
+            self.assertTrue('latitude' in dataset.dims.keys())
+            self.assertTrue('longitude' in dataset.dims.keys())
+
+            # Make sure the data arrays are 2D.
+            self.assertEqual(dataset['level_0_x'].shape, (18, 36))
+
+            # Make sure the dataset has the correct asset name.
+            self.assertTrue('asset_name' in dataset.attrs)
+            self.assertEqual(dataset.attrs['asset_name'], f'{dates[i // 3]}_{valid_times[i]}')
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py
index 553bb121..2b71d25d 100644
--- a/weather_mv/loader_pipeline/sinks.py
+++ b/weather_mv/loader_pipeline/sinks.py
@@ -226,7 +226,7 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset):
 def _to_utc_timestring(np_time: np.datetime64) -> str:
     """Turn a numpy datetime64 into UTC timestring."""
     timestamp = float((np_time - np.datetime64(0, 's')) / np.timedelta64(1, 's'))
-    return datetime.datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ')
+    return datetime.datetime.fromtimestamp(timestamp, datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
 
 
 def _add_is_normalized_attr(ds: xr.Dataset, value: bool) -> xr.Dataset:
diff --git a/weather_mv/loader_pipeline/util.py b/weather_mv/loader_pipeline/util.py
index 16ca435c..744e8a83 100644
--- a/weather_mv/loader_pipeline/util.py
+++ b/weather_mv/loader_pipeline/util.py
@@ -27,6 +27,7 @@
 import typing as t
 import uuid
 from functools import partial
+from string import Formatter
 from urllib.parse import urlparse
 import apache_beam as beam
 import numpy as np
@@ -327,6 +328,38 @@ def validate_region(output_table: t.Optional[str] = None,
         signal.signal(signal.SIGINT, original_sigtstp_handler)
 
 
+def get_dims_from_name_format(asset_name_format: str) -> t.List[str]:
+    """Returns a list of dimensions from the asset name format."""
+    return [field_name for _, field_name, _, _ in Formatter().parse(asset_name_format) if field_name]
+
+
+def get_datetime_from(value: np.datetime64) -> datetime.datetime:
+    """Converts a numpy datetime64 into a timezone-aware (UTC) datetime."""
+    return datetime.datetime.fromtimestamp(
+        (value - np.datetime64(0, 's')) // np.timedelta64(1, 's'),
+        datetime.timezone.utc
+    )
+
+
+def convert_to_string(value: t.Any, date_format: str = '%Y%m%d%H%M', make_ee_safe: bool = False) -> str:
+    """Converts a given value to a string based on the type of the value."""
+    def _make_ee_safe(str_val: str) -> str:
+        return re.sub(r'[^a-zA-Z0-9-_]+', r'_', str_val)
+
+    str_val = ''
+    if isinstance(value, np.ndarray) and value.size == 1:
+        value = value.item()
+
+    if isinstance(value, float):
+        str_val = str(round(value, 2))
+    elif isinstance(value, np.datetime64):
+        dt = get_datetime_from(value)
+        str_val = dt.strftime(date_format)
+    else:
+        str_val = str(value)
+
+    return _make_ee_safe(str_val) if make_ee_safe else str_val
+
+
 def _shard(elem, num_shards: int):
     return (np.random.randint(0, num_shards), elem)
 
diff --git a/weather_mv/loader_pipeline/util_test.py b/weather_mv/loader_pipeline/util_test.py
index 61b4a68d..956769cf 100644
--- a/weather_mv/loader_pipeline/util_test.py
+++ b/weather_mv/loader_pipeline/util_test.py
@@ -26,6 +26,7 @@
     ichunked,
     make_attrs_ee_compatible,
     to_json_serializable_type,
+    convert_to_string
 )
 
 
@@ -251,3 +252,16 @@ def test_to_json_serializable_type_datetime(self):
         self.assertEqual(self._convert(timedelta(seconds=1)), float(1))
         self.assertEqual(self._convert(timedelta(minutes=1)), float(60))
         self.assertEqual(self._convert(timedelta(days=1)), float(86400))
+
+
+class ConvertToStringTests(unittest.TestCase):
+
+    def test_convert_scalar_to_string(self):
+        self.assertEqual(convert_to_string(np.array(5)), '5')
+        self.assertEqual(convert_to_string(np.array(5.6789)), '5.68')
+
+    def test_convert_datetime_to_string(self):
+        value = np.datetime64('2025-01-24T04:05:06')
+        self.assertEqual(convert_to_string(value), '202501240405')
+        self.assertEqual(convert_to_string(value, '%Y/%m/%dT%H:%M:%S'), '2025/01/24T04:05:06')
+        self.assertEqual(convert_to_string(value, '%Y/%m/%dT%H:%M:%S', True), '2025_01_24T04_05_06')
diff --git a/weather_mv/test_data/test_data_multi_dimension.nc b/weather_mv/test_data/test_data_multi_dimension.nc
new file mode 100644
index 00000000..3df88fce
Binary files /dev/null and b/weather_mv/test_data/test_data_multi_dimension.nc differ
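Reviewer note (not part of the patch): the sketch below is a minimal illustration of how the new `partition_dataset` helper behaves. The toy dataset mirrors the dimension layout of `test_data_multi_dimension.nc` (time, step, level, latitude, longitude), but its sizes and values, the import path `weather_mv.loader_pipeline.ee`, and the printed outputs are assumptions derived from reading this change, not a documented API.

```python
import numpy as np
import xarray as xr

# Assumes the weather_mv package from this PR is importable.
from weather_mv.loader_pipeline.ee import partition_dataset

# Toy dataset: 2 init times x 2 lead times x 2 levels on an 18x36 grid.
times = np.array(['2024-12-01', '2024-12-02'], dtype='datetime64[ns]')
steps = np.array([0, 6], dtype='timedelta64[h]')
ds = xr.Dataset(
    {'x': (('time', 'step', 'level', 'latitude', 'longitude'),
           np.random.rand(2, 2, 2, 18, 36))},
    coords={
        'time': times,
        'step': steps,
        'level': [0, 1],
        'latitude': np.linspace(-85.0, 85.0, 18),
        'longitude': np.linspace(-175.0, 175.0, 36),
    },
)

parts = partition_dataset(
    ds,
    partition_dims=['time', 'step'],
    forecast_dim_mapping={'init_time': 'time', 'valid_time': 'step'},
    asset_name_format='{init_time}_{valid_time}',
    date_format='%Y%m%d%H%M',
)

# 2 times x 2 steps -> 4 partitions; the leftover 'level' dimension is
# flattened into the variable names of each partition.
print(len(parts))                    # expected: 4
print(list(parts[0].data_vars))      # expected: ['level_0_x', 'level_1_x']
print(parts[0].attrs['asset_name'])  # expected: '202412010000_202412010000'
```

Each partition keeps only latitude and longitude as dimensions, so `ConvertToAsset` can write it as a single multi-band COG whose file name comes from the partition's `asset_name` attribute.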