diff --git a/CHANGES.md b/CHANGES.md index db8859a4..0229f762 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,6 +4,8 @@ ### Improvements +- Add support for `fids` filter to `read_arrow` and `open_arrow`, and to + `read_dataframe` with `use_arrow=True` (#304). - Add some missing properties to `read_info`, including layer name, geometry name and FID column name (#365). - `read_arrow` and `open_arrow` now provide diff --git a/pyogrio/_io.pyx b/pyogrio/_io.pyx index 11a0663e..d8b55978 100644 --- a/pyogrio/_io.pyx +++ b/pyogrio/_io.pyx @@ -1333,7 +1333,11 @@ def ogr_open_arrow( raise ValueError("forcing 2D is not supported for Arrow") if fids is not None: - raise ValueError("reading by FID is not supported for Arrow") + if where is not None or bbox is not None or mask is not None or sql is not None or skip_features or max_features: + raise ValueError( + "cannot set both 'fids' and any of 'where', 'bbox', 'mask', " + "'sql', 'skip_features', or 'max_features'" + ) IF CTE_GDAL_VERSION < (3, 8, 0): if skip_features: @@ -1407,14 +1411,45 @@ def ogr_open_arrow( geometry_name = get_string(OGR_L_GetGeometryColumn(ogr_layer)) fid_column = get_string(OGR_L_GetFIDColumn(ogr_layer)) + fid_column_where = fid_column # OGR_L_GetFIDColumn returns the column name if it is a custom column, - # or "" if not. For arrow, the default column name is "OGC_FID". + # or "" if not. For arrow, the default column name used to return the FID data + # read is "OGC_FID". When accessing the underlying datasource like when using a + # where clause, the default column name is "FID". if fid_column == "": fid_column = "OGC_FID" + fid_column_where = "FID" + + # Use fids list to create a where clause, as arrow doesn't support direct fid + # filtering. + if fids is not None: + IF CTE_GDAL_VERSION < (3, 8, 0): + driver = get_driver(ogr_dataset) + if driver not in {"GPKG", "GeoJSON"}: + warnings.warn( + "Using 'fids' and 'use_arrow=True' with GDAL < 3.8 can be slow " + "for some drivers. Upgrading GDAL or using 'use_arrow=False' " + "can avoid this.", + stacklevel=2, + ) + + fids_str = ",".join([str(fid) for fid in fids]) + where = f"{fid_column_where} IN ({fids_str})" # Apply the attribute filter if where is not None and where != "": - apply_where_filter(ogr_layer, where) + try: + apply_where_filter(ogr_layer, where) + except ValueError as ex: + if fids is not None and str(ex).startswith("Invalid SQL query"): + # If fids is not None, the where being applied is the one formatted + # above. + raise ValueError( + f"error applying filter for {len(fids)} fids; max. number for " + f"drivers with default SQL dialect 'OGRSQL' is 4997" + ) from ex + + raise # Apply the spatial filter if bbox is not None: diff --git a/pyogrio/geopandas.py b/pyogrio/geopandas.py index 62015d6c..db762e31 100644 --- a/pyogrio/geopandas.py +++ b/pyogrio/geopandas.py @@ -160,7 +160,12 @@ def read_dataframe( the starting index is driver and file specific (e.g. typically 0 for Shapefile and 1 for GeoPackage, but can still depend on the specific file). The performance of reading a large number of features usings FIDs - is also driver specific. + is also driver specific and depends on the value of ``use_arrow``. The order + of the rows returned is undefined. If you would like to sort based on FID, use + ``fid_as_index=True`` to have the index of the GeoDataFrame returned set to the + FIDs of the features read. If ``use_arrow=True``, the number of FIDs is limited + to 4997 for drivers with 'OGRSQL' as default SQL dialect. To read a larger + number of FIDs, set ``user_arrow=False``. sql : str, optional (default: None) The SQL statement to execute. Look at the sql_dialect parameter for more information on the syntax to use for the query. When combined with other @@ -345,7 +350,7 @@ def write_dataframe( in the output file. path : str path to file - layer :str, optional (default: None) + layer : str, optional (default: None) layer name driver : string, optional (default: None) The OGR format driver used to write the vector file. By default write_dataframe @@ -545,7 +550,7 @@ def write_dataframe( # if possible use EPSG codes instead epsg = geometry.crs.to_epsg() if epsg: - crs = f"EPSG:{epsg}" + crs = f"EPSG:{epsg}" # noqa: E231 else: crs = geometry.crs.to_wkt(WktVersion.WKT1_GDAL) diff --git a/pyogrio/tests/test_geopandas_io.py b/pyogrio/tests/test_geopandas_io.py index b5ac2a82..4bc90dd9 100644 --- a/pyogrio/tests/test_geopandas_io.py +++ b/pyogrio/tests/test_geopandas_io.py @@ -518,14 +518,45 @@ def test_read_mask_where(naturalearth_lowres_all_ext, use_arrow): assert np.array_equal(df.iso_a3, ["CAN"]) -def test_read_fids(naturalearth_lowres_all_ext): +@pytest.mark.parametrize("fids", [[1, 5, 10], np.array([1, 5, 10], dtype=np.int64)]) +def test_read_fids(naturalearth_lowres_all_ext, fids, use_arrow): # ensure keyword is properly passed through - fids = np.array([1, 10, 5], dtype=np.int64) - df = read_dataframe(naturalearth_lowres_all_ext, fids=fids, fid_as_index=True) + df = read_dataframe( + naturalearth_lowres_all_ext, fids=fids, fid_as_index=True, use_arrow=use_arrow + ) assert len(df) == 3 assert np.array_equal(fids, df.index.values) +@requires_pyarrow_api +def test_read_fids_arrow_max_exception(naturalearth_lowres): + # Maximum number at time of writing is 4997 for "OGRSQL". For e.g. for SQLite based + # formats like Geopackage, there is no limit. + nb_fids = 4998 + fids = range(nb_fids) + with pytest.raises(ValueError, match=f"error applying filter for {nb_fids} fids"): + _ = read_dataframe(naturalearth_lowres, fids=fids, use_arrow=True) + + +@requires_pyarrow_api +@pytest.mark.skipif( + __gdal_version__ >= (3, 8, 0), reason="GDAL >= 3.8.0 does not need to warn" +) +def test_read_fids_arrow_warning_old_gdal(naturalearth_lowres_all_ext): + # A warning should be given for old GDAL versions, except for some file formats. + if naturalearth_lowres_all_ext.suffix not in [".gpkg", ".geojson"]: + handler = pytest.warns( + UserWarning, + match="Using 'fids' and 'use_arrow=True' with GDAL < 3.8 can be slow", + ) + else: + handler = contextlib.nullcontext() + + with handler: + df = read_dataframe(naturalearth_lowres_all_ext, fids=[22], use_arrow=True) + assert len(df) == 1 + + def test_read_fids_force_2d(test_fgdb_vsi): with pytest.warns( UserWarning, match=r"Measured \(M\) geometry types are not supported"