From c8bb107d24837c06a6d25c7bb313677731683041 Mon Sep 17 00:00:00 2001
From: Even Rouault <even.rouault@spatialys.com>
Date: Mon, 8 Jan 2024 20:34:17 +0100
Subject: [PATCH] Add GDALGetOutputDriversForDatasetName()

should help implementing https://github.com/georust/gdal/pull/510
---
 apps/commonutils.cpp | 134 ++++--------------------------------
 gcore/gdal.h         |   4 ++
 gcore/gdaldriver.cpp | 159 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 178 insertions(+), 119 deletions(-)

diff --git a/apps/commonutils.cpp b/apps/commonutils.cpp
index f2e472e76b7a..b7df3fb6179a 100644
--- a/apps/commonutils.cpp
+++ b/apps/commonutils.cpp
@@ -37,31 +37,6 @@
 #include "cpl_string.h"
 #include "gdal.h"
 
-/* -------------------------------------------------------------------- */
-/*                   DoesDriverHandleExtension()                        */
-/* -------------------------------------------------------------------- */
-
-static bool DoesDriverHandleExtension(GDALDriverH hDriver, const char *pszExt)
-{
-    bool bRet = false;
-    const char *pszDriverExtensions =
-        GDALGetMetadataItem(hDriver, GDAL_DMD_EXTENSIONS, nullptr);
-    if (pszDriverExtensions)
-    {
-        char **papszTokens = CSLTokenizeString(pszDriverExtensions);
-        for (int j = 0; papszTokens[j]; j++)
-        {
-            if (EQUAL(pszExt, papszTokens[j]))
-            {
-                bRet = true;
-                break;
-            }
-        }
-        CSLDestroy(papszTokens);
-    }
-    return bRet;
-}
-
 /* -------------------------------------------------------------------- */
 /*                         GetOutputDriversFor()                        */
 /* -------------------------------------------------------------------- */
@@ -70,72 +45,12 @@ std::vector<CPLString> GetOutputDriversFor(const char *pszDestFilename,
                                            int nFlagRasterVector)
 {
     std::vector<CPLString> aoDriverList;
-
-    CPLString osExt = CPLGetExtension(pszDestFilename);
-    if (EQUAL(osExt, "zip") &&
-        (CPLString(pszDestFilename).endsWith(".shp.zip") ||
-         CPLString(pszDestFilename).endsWith(".SHP.ZIP")))
-    {
-        osExt = "shp.zip";
-    }
-    else if (EQUAL(osExt, "zip") &&
-             (CPLString(pszDestFilename).endsWith(".gpkg.zip") ||
-              CPLString(pszDestFilename).endsWith(".GPKG.ZIP")))
-    {
-        osExt = "gpkg.zip";
-    }
-    const int nDriverCount = GDALGetDriverCount();
-    for (int i = 0; i < nDriverCount; i++)
-    {
-        GDALDriverH hDriver = GDALGetDriver(i);
-        bool bOk = false;
-        if ((GDALGetMetadataItem(hDriver, GDAL_DCAP_CREATE, nullptr) !=
-                 nullptr ||
-             GDALGetMetadataItem(hDriver, GDAL_DCAP_CREATECOPY, nullptr) !=
-                 nullptr) &&
-            (((nFlagRasterVector & GDAL_OF_RASTER) &&
-              GDALGetMetadataItem(hDriver, GDAL_DCAP_RASTER, nullptr) !=
-                  nullptr) ||
-             ((nFlagRasterVector & GDAL_OF_VECTOR) &&
-              GDALGetMetadataItem(hDriver, GDAL_DCAP_VECTOR, nullptr) !=
-                  nullptr)))
-        {
-            bOk = true;
-        }
-        else if (GDALGetMetadataItem(hDriver, GDAL_DCAP_VECTOR_TRANSLATE_FROM,
-                                     nullptr) &&
-                 (nFlagRasterVector & GDAL_OF_VECTOR) != 0)
-        {
-            bOk = true;
-        }
-        if (bOk)
-        {
-            if (!osExt.empty() && DoesDriverHandleExtension(hDriver, osExt))
-            {
-                aoDriverList.push_back(GDALGetDriverShortName(hDriver));
-            }
-            else
-            {
-                const char *pszPrefix = GDALGetMetadataItem(
-                    hDriver, GDAL_DMD_CONNECTION_PREFIX, nullptr);
-                if (pszPrefix && STARTS_WITH_CI(pszDestFilename, pszPrefix))
-                {
-                    aoDriverList.push_back(GDALGetDriverShortName(hDriver));
-                }
-            }
-        }
-    }
-
-    // GMT is registered before netCDF for opening reasons, but we want
-    // netCDF to be used by default for output.
-    if (EQUAL(osExt, "nc") && aoDriverList.size() == 2 &&
-        EQUAL(aoDriverList[0], "GMT") && EQUAL(aoDriverList[1], "NETCDF"))
-    {
-        aoDriverList.clear();
-        aoDriverList.push_back("NETCDF");
-        aoDriverList.push_back("GMT");
-    }
-
+    char **papszList = GDALGetOutputDriversForDatasetName(
+        pszDestFilename, nFlagRasterVector, /* bSingleMatch = */ false,
+        /* bEmitWarning = */ false);
+    for (char **papszIter = papszList; papszIter && *papszIter; ++papszIter)
+        aoDriverList.push_back(*papszIter);
+    CSLDestroy(papszList);
     return aoDriverList;
 }
 
@@ -145,36 +60,17 @@ std::vector<CPLString> GetOutputDriversFor(const char *pszDestFilename,
 
 CPLString GetOutputDriverForRaster(const char *pszDestFilename)
 {
-    CPLString osFormat;
-    std::vector<CPLString> aoDrivers =
-        GetOutputDriversFor(pszDestFilename, GDAL_OF_RASTER);
-    CPLString osExt(CPLGetExtension(pszDestFilename));
-    if (aoDrivers.empty())
+    char **papszList = GDALGetOutputDriversForDatasetName(
+        pszDestFilename, GDAL_OF_RASTER, /* bSingleMatch = */ true,
+        /* bEmitWarning = */ true);
+    if (papszList)
     {
-        if (osExt.empty())
-        {
-            osFormat = "GTiff";
-        }
-        else
-        {
-            CPLError(CE_Failure, CPLE_AppDefined, "Cannot guess driver for %s",
-                     pszDestFilename);
-            return "";
-        }
-    }
-    else
-    {
-        if (aoDrivers.size() > 1 &&
-            !(aoDrivers[0] == "GTiff" && aoDrivers[1] == "COG"))
-        {
-            CPLError(CE_Warning, CPLE_AppDefined,
-                     "Several drivers matching %s extension. Using %s",
-                     osExt.c_str(), aoDrivers[0].c_str());
-        }
-        osFormat = aoDrivers[0];
+        CPLDebug("GDAL", "Using %s driver", papszList[0]);
+        const std::string osRet = papszList[0];
+        CSLDestroy(papszList);
+        return osRet;
     }
-    CPLDebug("GDAL", "Using %s driver", osFormat.c_str());
-    return osFormat;
+    return CPLString();
 }
 
 /* -------------------------------------------------------------------- */
diff --git a/gcore/gdal.h b/gcore/gdal.h
index 8ce27a5a2431..40452553d673 100644
--- a/gcore/gdal.h
+++ b/gcore/gdal.h
@@ -1040,6 +1040,10 @@ CPLErr CPL_DLL CPL_STDCALL GDALCopyDatasetFiles(GDALDriverH,
                                                 const char *pszOldName);
 int CPL_DLL CPL_STDCALL
 GDALValidateCreationOptions(GDALDriverH, CSLConstList papszCreationOptions);
+char CPL_DLL **GDALGetOutputDriversForDatasetName(const char *pszDestFilename,
+                                                  int nFlagRasterVector,
+                                                  bool bSingleMatch,
+                                                  bool bEmitWarning);
 
 /* The following are deprecated */
 const char CPL_DLL *CPL_STDCALL GDALGetDriverShortName(GDALDriverH);
diff --git a/gcore/gdaldriver.cpp b/gcore/gdaldriver.cpp
index 4cda88e7bc09..cbca9f575c3b 100644
--- a/gcore/gdaldriver.cpp
+++ b/gcore/gdaldriver.cpp
@@ -2810,3 +2810,162 @@ CPLErr GDALDriver::SetMetadataItem(const char *pszName, const char *pszValue,
     }
     return GDALMajorObject::SetMetadataItem(pszName, pszValue, pszDomain);
 }
+
+/************************************************************************/
+/*                   DoesDriverHandleExtension()                        */
+/************************************************************************/
+
+static bool DoesDriverHandleExtension(GDALDriverH hDriver, const char *pszExt)
+{
+    bool bRet = false;
+    const char *pszDriverExtensions =
+        GDALGetMetadataItem(hDriver, GDAL_DMD_EXTENSIONS, nullptr);
+    if (pszDriverExtensions)
+    {
+        const CPLStringList aosTokens(CSLTokenizeString(pszDriverExtensions));
+        const int nTokens = aosTokens.size();
+        for (int j = 0; j < nTokens; ++j)
+        {
+            if (EQUAL(pszExt, aosTokens[j]))
+            {
+                bRet = true;
+                break;
+            }
+        }
+    }
+    return bRet;
+}
+
+/************************************************************************/
+/*                  GDALGetOutputDriversForDatasetName()                */
+/************************************************************************/
+
+/** Return a list of driver short names that are likely candidates for the
+ * provided output file name.
+ *
+ * @param pszDestDataset Output dataset name (might not exist).
+ * @param nFlagRasterVector GDAL_OF_RASTER, GDAL_OF_VECTOR or
+ *                          binary-or'ed combination of both
+ * @param bSingleMatch Whether a single match is desired. In this mode, if
+ *                     nFlagRasterVector==GDAL_OF_RASTER and pszDestDataset has
+ *                     no extension, GTiff will be selected.
+ * @param bEmitWarning Whether a warning should be emitted when bSingleMatch is
+ *                     true and there are more than 2 candidates.
+ * @return NULL terminated list of driver short names.
+ * To be freed with CSLDestroy()
+ * @since 3.9
+ */
+char **GDALGetOutputDriversForDatasetName(const char *pszDestDataset,
+                                          int nFlagRasterVector,
+                                          bool bSingleMatch, bool bEmitWarning)
+{
+    CPLStringList aosDriverNames;
+
+    std::string osExt = CPLGetExtension(pszDestDataset);
+    if (EQUAL(osExt.c_str(), "zip"))
+    {
+        const CPLString osLower(CPLString(pszDestDataset).tolower());
+        if (osLower.endsWith(".shp.zip"))
+        {
+            osExt = "shp.zip";
+        }
+        else if (osLower.endsWith(".gpkg.zip"))
+        {
+            osExt = "gpkg.zip";
+        }
+    }
+
+    const int nDriverCount = GDALGetDriverCount();
+    for (int i = 0; i < nDriverCount; i++)
+    {
+        GDALDriverH hDriver = GDALGetDriver(i);
+        bool bOk = false;
+        if ((GDALGetMetadataItem(hDriver, GDAL_DCAP_CREATE, nullptr) !=
+                 nullptr ||
+             GDALGetMetadataItem(hDriver, GDAL_DCAP_CREATECOPY, nullptr) !=
+                 nullptr) &&
+            (((nFlagRasterVector & GDAL_OF_RASTER) &&
+              GDALGetMetadataItem(hDriver, GDAL_DCAP_RASTER, nullptr) !=
+                  nullptr) ||
+             ((nFlagRasterVector & GDAL_OF_VECTOR) &&
+              GDALGetMetadataItem(hDriver, GDAL_DCAP_VECTOR, nullptr) !=
+                  nullptr)))
+        {
+            bOk = true;
+        }
+        else if (GDALGetMetadataItem(hDriver, GDAL_DCAP_VECTOR_TRANSLATE_FROM,
+                                     nullptr) &&
+                 (nFlagRasterVector & GDAL_OF_VECTOR) != 0)
+        {
+            bOk = true;
+        }
+        if (bOk)
+        {
+            if (!osExt.empty() &&
+                DoesDriverHandleExtension(hDriver, osExt.c_str()))
+            {
+                aosDriverNames.AddString(GDALGetDriverShortName(hDriver));
+            }
+            else
+            {
+                const char *pszPrefix = GDALGetMetadataItem(
+                    hDriver, GDAL_DMD_CONNECTION_PREFIX, nullptr);
+                if (pszPrefix && STARTS_WITH_CI(pszDestDataset, pszPrefix))
+                {
+                    aosDriverNames.AddString(GDALGetDriverShortName(hDriver));
+                }
+            }
+        }
+    }
+
+    // GMT is registered before netCDF for opening reasons, but we want
+    // netCDF to be used by default for output.
+    if (EQUAL(osExt.c_str(), "nc") && aosDriverNames.size() == 2 &&
+        EQUAL(aosDriverNames[0], "GMT") && EQUAL(aosDriverNames[1], "netCDF"))
+    {
+        aosDriverNames.Clear();
+        aosDriverNames.AddString("netCDF");
+        aosDriverNames.AddString("GMT");
+    }
+
+    if (bSingleMatch)
+    {
+        if (nFlagRasterVector == GDAL_OF_RASTER)
+        {
+            if (aosDriverNames.empty())
+            {
+                if (osExt.empty())
+                {
+                    aosDriverNames.AddString("GTiff");
+                }
+            }
+            else if (aosDriverNames.size() >= 2)
+            {
+                if (bEmitWarning && !(EQUAL(aosDriverNames[0], "GTiff") &&
+                                      EQUAL(aosDriverNames[1], "COG")))
+                {
+                    CPLError(CE_Warning, CPLE_AppDefined,
+                             "Several drivers matching %s extension. Using %s",
+                             osExt.c_str(), aosDriverNames[0]);
+                }
+                const std::string osDrvName = aosDriverNames[0];
+                aosDriverNames.Clear();
+                aosDriverNames.AddString(osDrvName.c_str());
+            }
+        }
+        else if (aosDriverNames.size() >= 2)
+        {
+            if (bEmitWarning)
+            {
+                CPLError(CE_Warning, CPLE_AppDefined,
+                         "Several drivers matching %s extension. Using %s",
+                         osExt.c_str(), aosDriverNames[0]);
+            }
+            const std::string osDrvName = aosDriverNames[0];
+            aosDriverNames.Clear();
+            aosDriverNames.AddString(osDrvName.c_str());
+        }
+    }
+
+    return aosDriverNames.StealList();
+}