From 6e1f50fcaf87c871bce4641c87f7cdb32e7df8da Mon Sep 17 00:00:00 2001 From: Tobias Schmidt <13055656+schmidtnz@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:58:32 +1300 Subject: [PATCH] refactor: config dataclass and more helper functions for standardising.py --- scripts/gdal/gdal_commands.py | 2 +- scripts/standardise_validate.py | 27 ++- scripts/standardising.py | 360 +++++++++++++++++++------------- 3 files changed, 231 insertions(+), 158 deletions(-) diff --git a/scripts/gdal/gdal_commands.py b/scripts/gdal/gdal_commands.py index ae323ad8..9cdbe4fa 100644 --- a/scripts/gdal/gdal_commands.py +++ b/scripts/gdal/gdal_commands.py @@ -121,7 +121,7 @@ def get_alpha_command() -> list[str]: ] -def get_transform_srs_command(source_epsg: str, target_epsg: str) -> list[str]: +def get_transform_srs_command(source_epsg: int, target_epsg: int) -> list[str]: """Get a `gdalwarp` command to transform the srs. Args: diff --git a/scripts/standardise_validate.py b/scripts/standardise_validate.py index 833787c9..6aba3607 100644 --- a/scripts/standardise_validate.py +++ b/scripts/standardise_validate.py @@ -14,7 +14,7 @@ from scripts.gdal.gdal_helper import get_srs, get_vfs_path from scripts.json_codec import dict_to_json_bytes from scripts.stac.imagery.create_stac import create_item -from scripts.standardising import run_standardising +from scripts.standardising import StandardisingConfig, run_standardising def str_to_bool(value: str) -> bool: @@ -30,7 +30,7 @@ def str_to_list_or_none(values: str) -> list[Decimal] | None: return None result = [Decimal(val) for val in values.split(",")] if len(result) != 2: - raise argparse.ArgumentTypeError(f"Invalid list (must be blank or exactly 2 values): {values}") + raise argparse.ArgumentTypeError(f"Invalid list - must be blank or exactly 2 values x,y. Received: {values}") return result @@ -49,10 +49,13 @@ def parse_args() -> argparse.Namespace: ), required=False, ) - parser.add_argument("--source-epsg", dest="source_epsg", required=True, help="The EPSG code of the source imagery") + parser.add_argument( + "--source-epsg", dest="source_epsg", type=int, required=True, help="The EPSG code of the source imagery" + ) parser.add_argument( "--target-epsg", dest="target_epsg", + type=int, required=True, help="The target EPSG code. 
If different to source the imagery will be reprojected", ) @@ -124,6 +127,16 @@ def report_non_visual_qa_errors(file: FileTiff) -> None: def main() -> None: arguments = parse_args() + standardising_config = StandardisingConfig( + gdal_preset=arguments.preset, + source_epsg=arguments.source_epsg, + target_epsg=arguments.target_epsg, + gsd=arguments.gsd, + create_footprints=arguments.create_footprints, + cutline=arguments.cutline, + scale_to_resolution=arguments.scale_to_resolution, + ) + try: tile_files = load_input_files(arguments.from_file) except InputParameterError as e: @@ -149,16 +162,10 @@ def main() -> None: tiff_files = run_standardising( tile_files, - arguments.preset, - arguments.cutline, + standardising_config, concurrency, - arguments.source_epsg, - arguments.target_epsg, - arguments.gsd, - arguments.create_footprints, gdal_version, arguments.target, - arguments.scale_to_resolution, ) if len(tiff_files) == 0: diff --git a/scripts/standardising.py b/scripts/standardising.py index 4bc5d020..394c5b67 100644 --- a/scripts/standardising.py +++ b/scripts/standardising.py @@ -2,6 +2,7 @@ import os import tempfile +from dataclasses import dataclass from decimal import Decimal from functools import partial from multiprocessing import Pool @@ -28,32 +29,53 @@ from scripts.tile.tile_index import Bounds, get_bounds_from_name +@dataclass +class StandardisingConfig: + """Standardising configuration. + gdal_preset: gdal preset to use. See `gdal.gdal_preset.py` + source_epsg: current EPSG code of the source file + target_epsg: desired EPSG code of the output file + gsd: expected Ground Sample Distance in meters + create_footprints: whether to create a footprint sidecar file + cutline: path to the cutline file. Must be `.fgb` or `.geojson` + scale_to_resolution: scale TIFFs to the specified x,y resolution. Defaults to None = no scaling. + """ + + gdal_preset: str + source_epsg: int + target_epsg: int + gsd: Decimal + create_footprints: bool + cutline: str | None + scale_to_resolution: list[Decimal] | None = None + + def __post_init__(self) -> None: + if self.cutline is not None and not self.cutline.endswith((".fgb", ".geojson")): + raise ValueError(f"Only .fgb or .geojson cutlines are supported: {self.cutline}") + if self.scale_to_resolution is not None and len(self.scale_to_resolution) != 2: + raise ValueError(f"scale_to_resolution must be exactly two items [xres, yres]: {self.scale_to_resolution}") + + def run_standardising( - todo: list[TileFiles], - preset: str, - cutline: str | None, + files_to_process: list[TileFiles], + standardising_config: StandardisingConfig, concurrency: int, - source_epsg: str, - target_epsg: str, - gsd: Decimal, - create_footprints: bool, gdal_version: str, target_output: str = "/tmp/", - scale_to_resolution: list[Decimal] | None = None, ) -> list[FileTiff]: """Run `standardising()` in parallel (`concurrency`). Args: - todo: list of TileFiles (tile name and input files) to standardise - preset: gdal preset to use. See `gdal.gdal_preset.py` - cutline: path to the cutline. Must be `.fgb` or `.geojson` + files_to_process: list of TileFiles (tile name and input files) to standardise + standardising_config: a StandardisingConfig data class, containing + gdal_preset: gdal preset to use. See `gdal.gdal_preset.py` + source_epsg: current EPSG code of the source file + target_epsg: desired EPSG code of the output file + gsd: expected Ground Sample Distance in meters + create_footprints: whether to create a footprint sidecar file + cutline: path to the cutline file. Must be `.fgb` or `.geojson` + scale_to_resolution: scale TIFFs to the specified x,y resolution. Defaults to None = no scaling.
concurrency: number of concurrent files to process - source_epsg: EPSG code of the source file - target_epsg: EPSG code of reprojection - gsd: Ground Sample Distance in meters gdal_version: version of GDAL used for standardising target_output: output directory path. Defaults to "/tmp/" - scale_to_resolution: scale TIFFs to the specified x,y resolution. Defaults to None = no scaling. Returns: @@ -62,7 +84,7 @@ def run_standardising( # pylint: disable-msg=too-many-arguments start_time = time_in_ms() - get_log().info("standardising_start", gdalVersion=gdal_version, fileCount=len(todo)) + get_log().info("standardising_start", gdalVersion=gdal_version, fileCount=len(files_to_process)) with Pool(concurrency) as p: standardized_tiffs = [ @@ -70,16 +92,10 @@ for entry in p.map( partial( standardising, - preset=preset, - source_epsg=source_epsg, - target_epsg=target_epsg, + config=standardising_config, target_output=target_output, - gsd=gsd, - create_footprints=create_footprints, - cutline=cutline, - scale_to_resolution=scale_to_resolution, ), - todo, + files_to_process, ) if entry is not None ] @@ -111,31 +127,27 @@ def create_vrt( return vrt_path -# pylint: disable-msg=too-many-locals -# pylint: disable-msg=too-many-statements -# pylint: disable-msg=too-many-arguments def standardising( files: TileFiles, - preset: str, - source_epsg: str, - target_epsg: str, - gsd: Decimal, - create_footprints: bool, - cutline: str | None, + config: StandardisingConfig, target_output: str = "/tmp/", - scale_to_resolution: list[Decimal] | None = None, ) -> FileTiff | None: - """Apply transformations using GDAL to the source file and create a footprint sidecar file. + """Standardise geospatial TIFF files using GDAL. + Optionally create a footprint sidecar file. Args: - files: paths to the files to standardise - preset: gdal preset to use. See `gdal.gdal_preset.py` - source_epsg: EPSG code of the source file - target_epsg: EPSG code of reprojection - gsd: Ground Sample Distance in meters - cutline: path to the cutline. Must be `.fgb` or `.geojson` - target_output: output directory path. Defaults to "/tmp/" - scale_to_resolution: scale TIFFs to the specified x,y resolution. Defaults to None = no scaling. + files: a TileFiles named tuple, containing + output: name of the output tile to be created + inputs: list of input files for the creation of the output tile + includeDerived: whether STAC should include derived_from links + config: a StandardisingConfig data class, containing + gdal_preset: gdal preset to use. See `gdal.gdal_preset.py` + source_epsg: current EPSG code of the source file + target_epsg: desired EPSG code of the output file + gsd: expected Ground Sample Distance in meters + create_footprints: whether to create a footprint sidecar file + cutline: path to the cutline file. Must be `.fgb` or `.geojson` + scale_to_resolution: scale TIFFs to the specified x,y resolution. Defaults to None = no scaling. + target_output: output directory path. Defaults to "/tmp/". Not to be confused with `tmp_path`.
Raises: Exception: if cutline is not a .fgb or .geojson file @@ -143,118 +155,172 @@ def standardising( Returns: a FileTiff wrapper """ - standardized_file_name = files.output + ".tiff" - footprint_file_name = files.output + SUFFIX_FOOTPRINT - standardized_file_path = os.path.join(target_output, standardized_file_name) - footprint_file_path = os.path.join(target_output, footprint_file_name) - tiff = FileTiff(files.inputs, preset, files.includeDerived) - tiff.set_path_standardised(standardized_file_path) - - # Already proccessed can skip processing - if exists(standardized_file_path): - get_log().info("standardised_tiff_already_exists", path=standardized_file_path) + standardised_file_path = os.path.join(target_output, f"{files.output}.tiff") + tiff = FileTiff(files.inputs, config.gdal_preset, files.includeDerived) + tiff.set_path_standardised(standardised_file_path) + + # Skip processing if output file already exists + if exists(standardised_file_path): + get_log().info("standardised_tiff_already_exists", path=standardised_file_path) return tiff # Download any needed file from S3 ["/foo/bar.tiff", "s3://foo"] => "/tmp/bar.tiff", "/tmp/foo.tiff" with tempfile.TemporaryDirectory() as tmp_path: - standardized_working_path = os.path.join(tmp_path, standardized_file_name) - footprint_tmp_path = os.path.join(tmp_path, footprint_file_name) - sidecars: list[str] = [] - for extension in [".prj", ".tfw"]: - for file_input in tiff.get_paths_original(): - sidecars.append(f"{os.path.splitext(file_input)[0]}{extension}") + + # Handle sidecar files and copy source TIFFs to tmp_path + generate_prj_tfw_sidecars(tiff, f"{tmp_path}/source/") source_files = write_all(tiff.get_paths_original(), f"{tmp_path}/source/") - write_sidecars(sidecars, f"{tmp_path}/source/") - source_tiffs = [file for file in source_files if is_tiff(file)] + # Determine if VRT needs alpha + vrt_add_alpha = check_vrt_alpha(source_files) + + # Create base VRT file + current_working_file = create_vrt( + [source_file for source_file in source_files if is_tiff(source_file)], + tmp_path, + add_alpha=vrt_add_alpha, + resolution=config.scale_to_resolution, + ) + + # Apply cutline if needed + current_working_file = apply_cutline(current_working_file, config, tmp_path) + + # Add alpha band to imagery + current_working_file = add_alpha_to_imagery(current_working_file, tiff, tmp_path) + + # Reproject if needed + current_working_file = reproject_if_needed(current_working_file, config, tmp_path) + + # Generate output using GDAL + current_working_file = apply_gdal_transformation(current_working_file, config, tmp_path, tile_name=files.output) + + # Update GDAL info + tiff.get_gdalinfo(current_working_file) + + # Validate output and create footprints + if check_tiff_empty(current_working_file): + return None + + if config.create_footprints: + temp_footprint = create_footprint(current_working_file, config, tmp_path) + footprint_file_path = os.path.join(target_output, f"{files.output}{SUFFIX_FOOTPRINT}") + write(footprint_file_path, read(temp_footprint), content_type=ContentType.GEOJSON.value) + + # Copy the final version of the working / temp file to the desired destination + write(standardised_file_path, read(current_working_file), content_type=ContentType.GEOTIFF.value) + + return tiff + + +def generate_prj_tfw_sidecars(tiff: FileTiff, target_path: str) -> list[str]: + """Generate sidecar files (prj, tfw) for the TIFF files.""" + sidecars = [ + f"{os.path.splitext(file_input)[0]}{extension}" + for extension in [".prj", ".tfw"] + for file_input in 
tiff.get_paths_original() + ] + write_sidecars(sidecars, target_path) + return sidecars - vrt_add_alpha = True - for file in source_tiffs: - gdal_data = gdal_info(file) - bands = gdal_data["bands"] +def check_vrt_alpha(source_files: list[str]) -> bool: + """Check if alpha is needed in the VRT.""" + for file in source_files: + if is_tiff(file): + bands = gdal_info(file)["bands"] if (len(bands) == 4 and bands[3]["colorInterpretation"] == "Alpha") or ( len(bands) == 1 and bands[0]["colorInterpretation"] == "Gray" ): - vrt_add_alpha = False - - # Start from base VRT - input_file = create_vrt(source_tiffs, tmp_path, add_alpha=vrt_add_alpha, resolution=scale_to_resolution) - - # Apply cutline - if cutline: - input_cutline_path = cutline - if is_s3(cutline): - if not cutline.endswith((".fgb", ".geojson")): - raise Exception(f"Only .fgb or .geojson cutlines are support cutline:{cutline}") - input_cutline_path = os.path.join(tmp_path, "cutline" + os.path.splitext(cutline)[1]) - # Ensure the input cutline is a easy spot for GDAL to read - write(input_cutline_path, read(cutline)) - - target_vrt = os.path.join(tmp_path, "cutline.vrt") - run_gdal(get_cutline_command(input_cutline_path), input_file=input_file, output_file=target_vrt) - input_file = target_vrt - elif tiff.get_tiff_type() == FileTiffType.IMAGERY: - target_vrt = os.path.join(tmp_path, "target.vrt") - # add alpha band to all imagery for consistency allowing GDAL to run correctly (TDE-804) - run_gdal(get_alpha_command(), input_file=input_file, output_file=target_vrt) - input_file = target_vrt - - # Reproject tiff if needed - if source_epsg != target_epsg: - target_vrt = os.path.join(tmp_path, "reproject.vrt") - get_log().info("Reprojecting Tiff", path=input_file, sourceEPSG=source_epsg, targetEPSG=target_epsg) - run_gdal(get_transform_srs_command(source_epsg, target_epsg), input_file=input_file, output_file=target_vrt) - input_file = target_vrt - - transformed_image_gdalinfo = gdal_info(input_file) - command = get_gdal_command(preset, epsg=int(target_epsg)) - command.extend(get_gdal_band_offset(input_file, transformed_image_gdalinfo, preset)) - - # Specify the extent to get the right boundaries in case of the tiff got no data on its edges - output_bounds: Bounds = get_bounds_from_name(files.output) - min_x = output_bounds.point.x - max_y = output_bounds.point.y - min_y = max_y - output_bounds.size.height - max_x = min_x + output_bounds.size.width - command.extend(["-co", f"TARGET_SRS=EPSG:{target_epsg}", "-co", f"EXTENT={min_x},{min_y},{max_x},{max_y}"]) - - # Need GDAL to write to temporary location so no broken files end up in the done folder. 
- run_gdal(command, input_file=input_file, output_file=standardized_working_path) - - # Update the `FileTiff.gdalinfo` as the system has local access to the TIFF - get_log().debug("Saving gdalinfo from local standardised TIFF in FileTiff object", path=standardized_working_path) - tiff.get_gdalinfo(standardized_working_path) - - with TiffFile(standardized_working_path) as file_handle: - if any(tile_byte_count != 0 for tile_byte_count in file_handle.pages.first.tags["TileByteCounts"].value): - if create_footprints: - # Create footprint GeoJSON - run_gdal( - [ - "gdal_footprint", - "-t_srs", - f"EPSG:{EpsgNumber.WGS_1984.value}", - "-max_points", - "unlimited", - # Round to 8 decimal places to reduce unnecessary complexity - "-lco", - "COORDINATE_PRECISION=8", - "-simplify", - str(get_buffer_distance(gsd)), - ], - standardized_working_path, - footprint_tmp_path, - ) - write( - footprint_file_path, - read(footprint_tmp_path), - content_type=ContentType.GEOJSON.value, - ) - - write(standardized_file_path, read(standardized_working_path), content_type=ContentType.GEOTIFF.value) - - return tiff - - get_log().info("Skipping empty output image", path=input_file, sourceEPSG=source_epsg, targetEPSG=target_epsg) - return None + return False + return True + + +def apply_cutline(input_file: str, config: StandardisingConfig, tmp_path: str) -> str: + """Apply a cutline to the input VRT if a cutline is provided.""" + if config.cutline: + if not config.cutline.endswith((".fgb", ".geojson")): + raise ValueError(f"Only .fgb or .geojson cutlines are supported: {config.cutline}") + + input_cutline_path = config.cutline + if is_s3(config.cutline): + input_cutline_path = os.path.join(tmp_path, "cutline" + os.path.splitext(config.cutline)[1]) + write(input_cutline_path, read(config.cutline)) + + target_vrt = os.path.join(tmp_path, "cutline.vrt") + run_gdal(get_cutline_command(input_cutline_path), input_file=input_file, output_file=target_vrt) + return target_vrt + + return input_file + + +def add_alpha_to_imagery(input_file: str, tiff: FileTiff, tmp_path: str) -> str: + """Add alpha band to all imagery for consistency allowing GDAL to run correctly (TDE-804).""" + if tiff.get_tiff_type() == FileTiffType.IMAGERY: + target_vrt = os.path.join(tmp_path, "target.vrt") + get_log().info("Adding alpha band to TIFF", path=input_file, existingTiffBands=gdal_info(input_file)["bands"]) + run_gdal(get_alpha_command(), input_file=input_file, output_file=target_vrt) + return target_vrt + + return input_file + + +def reproject_if_needed(input_file: str, config: StandardisingConfig, tmp_path: str) -> str: + """Reproject the VRT file if source EPSG differs from target EPSG.""" + if config.source_epsg != config.target_epsg: + target_vrt = os.path.join(tmp_path, "reproject.vrt") + get_log().info("Reprojecting TIFF", path=input_file, sourceEPSG=config.source_epsg, targetEPSG=config.target_epsg) + run_gdal( + get_transform_srs_command(config.source_epsg, config.target_epsg), input_file=input_file, output_file=target_vrt + ) + return target_vrt + return input_file + + +def apply_gdal_transformation(input_file: str, config: StandardisingConfig, tmp_path: str, tile_name: str) -> str: + """Generate output using GDAL command.""" + target_file = os.path.join(tmp_path, f"{tile_name}.tiff") + + command = get_gdal_command(config.gdal_preset, epsg=int(config.target_epsg)) + command.extend(get_gdal_band_offset(input_file, gdal_info(input_file), config.gdal_preset)) + + # Specify the extent to get the right boundaries in case of the tiff got no 
data on its edges + output_bounds: Bounds = get_bounds_from_name(tile_name) + min_x = output_bounds.point.x + max_y = output_bounds.point.y + min_y = max_y - output_bounds.size.height + max_x = min_x + output_bounds.size.width + command.extend(["-co", f"TARGET_SRS=EPSG:{config.target_epsg}", "-co", f"EXTENT={min_x},{min_y},{max_x},{max_y}"]) + + get_log().info("Running GDAL", command=command, input_file=input_file, output_file=target_file) + + # Need GDAL to write to temporary location so no broken files end up in the done folder. + run_gdal(command, input_file=input_file, output_file=target_file) + + return target_file + + +def check_tiff_empty(current_working_file: str) -> bool: + """Return True if the TIFF output is empty (all tile byte counts are zero).""" + with TiffFile(current_working_file) as file_handle: + return all(tile_byte_count == 0 for tile_byte_count in file_handle.pages.first.tags["TileByteCounts"].value) + + +def create_footprint(current_working_file: str, config: StandardisingConfig, tmp_path: str) -> str: + """Create the footprint from the standardised TIFF.""" + footprint_tmp_path = os.path.join(tmp_path, f"footprint{SUFFIX_FOOTPRINT}") + run_gdal( + [ + "gdal_footprint", + "-t_srs", + f"EPSG:{EpsgNumber.WGS_1984.value}", + "-max_points", + "unlimited", + # Round to 8 decimal places to reduce unnecessary complexity + "-lco", + "COORDINATE_PRECISION=8", + "-simplify", + str(get_buffer_distance(config.gsd)), + ], + current_working_file, + footprint_tmp_path, + ) + return footprint_tmp_path
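
Reviewer note (not part of the patch): a minimal sketch of what a call site looks like after this refactor, mirroring the changes to scripts/standardise_validate.py above. The preset name, EPSG codes, GSD, concurrency, GDAL version string and the empty tile_files list are illustrative placeholders, not values taken from this patch. Because StandardisingConfig.__post_init__ validates the options up front, an unsupported cutline extension or a malformed scale_to_resolution now fails before any GDAL work starts.

from decimal import Decimal

from scripts.standardising import StandardisingConfig, run_standardising

# All GDAL-related options now travel together in one config object that is
# validated by __post_init__ (cutline extension, scale_to_resolution length).
config = StandardisingConfig(
    gdal_preset="webp",  # placeholder preset name
    source_epsg=2193,  # placeholder EPSG codes
    target_epsg=2193,
    gsd=Decimal("0.3"),
    create_footprints=True,
    cutline=None,  # or a path to a .fgb / .geojson cutline
    scale_to_resolution=None,  # e.g. [Decimal("0.3"), Decimal("0.3")]
)

tile_files: list = []  # placeholder; the real script builds this with load_input_files()
tiff_files = run_standardising(
    tile_files,
    config,
    concurrency=4,
    gdal_version="3.8.4",  # placeholder version string
    target_output="/tmp/",
)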