diff --git a/autotest/gcore/thread_test.py b/autotest/gcore/thread_test.py index f57281c00fc2..91e8d574b202 100755 --- a/autotest/gcore/thread_test.py +++ b/autotest/gcore/thread_test.py @@ -88,7 +88,7 @@ def verify_checksum(): res[0] = False assert False, (got_cs, expected_cs) - threads = [threading.Thread(target=verify_checksum)] + threads = [threading.Thread(target=verify_checksum) for i in range(2)] for t in threads: t.start() for t in threads: @@ -426,3 +426,190 @@ def test_thread_safe_unsupported_rat(): match="not supporting a non-GDALDefaultRasterAttributeTable implementation", ): ds.GetRasterBand(1).GetDefaultRAT() + + +def test_thread_safe_many_datasets(): + + tab_ds = [ + gdal.OpenEx( + "data/byte.tif" if (i % 3) < 2 else "data/utmsmall.tif", + gdal.OF_RASTER | gdal.OF_THREAD_SAFE, + ) + for i in range(100) + ] + + res = [True] + + def check(): + for _ in range(10): + for i, ds in enumerate(tab_ds): + if ds.GetRasterBand(1).Checksum() != (4672 if (i % 3) < 2 else 50054): + res[0] = False + + threads = [threading.Thread(target=check) for i in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + assert res[0] + + +def test_thread_safe_BeginAsyncReader(): + + with gdal.OpenEx("data/byte.tif", gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds: + with pytest.raises(Exception, match="not supported"): + ds.BeginAsyncReader(0, 0, ds.RasterXSize, ds.RasterYSize) + + +def test_thread_safe_GetVirtualMem(): + + pytest.importorskip("numpy") + pytest.importorskip("osgeo.gdal_array") + + with gdal.OpenEx("data/byte.tif", gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds: + with pytest.raises(Exception, match="not supported"): + ds.GetRasterBand(1).GetVirtualMemAutoArray(gdal.GF_Read) + + +def test_thread_safe_GetMetadadata(tmp_vsimem): + + filename = str(tmp_vsimem / "test.tif") + with gdal.GetDriverByName("GTiff").Create(filename, 1, 1) as ds: + ds.SetMetadataItem("foo", "bar") + ds.GetRasterBand(1).SetMetadataItem("bar", "baz") + + with gdal.OpenEx(filename, gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds: + assert ds.GetMetadataItem("foo") == "bar" + assert ds.GetMetadataItem("not existing") is None + assert ds.GetMetadata() == {"foo": "bar"} + assert ds.GetMetadata("not existing") == {} + assert ds.GetRasterBand(1).GetMetadataItem("bar") == "baz" + assert ds.GetRasterBand(1).GetMetadataItem("not existing") is None + assert ds.GetRasterBand(1).GetMetadata() == {"bar": "baz"} + assert ds.GetRasterBand(1).GetMetadata("not existing") == {} + + +def test_thread_safe_GetUnitType(tmp_vsimem): + + filename = str(tmp_vsimem / "test.tif") + with gdal.GetDriverByName("GTiff").Create(filename, 1, 1) as ds: + ds.GetRasterBand(1).SetUnitType("foo") + + with gdal.OpenEx(filename, gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds: + assert ds.GetRasterBand(1).GetUnitType() == "foo" + + +def test_thread_safe_GetColorTable(tmp_vsimem): + + filename = str(tmp_vsimem / "test.tif") + with gdal.GetDriverByName("GTiff").Create(filename, 1, 1) as ds: + ct = gdal.ColorTable() + ct.SetColorEntry(0, (1, 2, 3, 255)) + ds.GetRasterBand(1).SetColorTable(ct) + + with gdal.OpenEx(filename, gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds: + res = [None] + + def thread_job(): + res[0] = ds.GetRasterBand(1).GetColorTable() + + t = threading.Thread(target=thread_job) + t.start() + t.join() + assert res[0] + assert res[0].GetColorEntry(0) == (1, 2, 3, 255) + ct = ds.GetRasterBand(1).GetColorTable() + assert ct.GetColorEntry(0) == (1, 2, 3, 255) + + +def test_thread_safe_GetSpatialRef(): + + with gdal.OpenEx("data/byte.tif", gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds: + + res = [True] + + def check(): + for i in range(100): + + if len(ds.GetGCPs()) != 0: + res[0] = False + assert False + + if ds.GetGCPSpatialRef(): + res[0] = False + assert False + + if ds.GetGCPProjection(): + res[0] = False + assert False + + srs = ds.GetSpatialRef() + if not srs: + res[0] = False + assert False + if not srs.IsProjected(): + res[0] = False + assert False + if "NAD27 / UTM zone 11N" not in srs.ExportToWkt(): + res[0] = False + assert False + + if "NAD27 / UTM zone 11N" not in ds.GetProjectionRef(): + res[0] = False + assert False + + threads = [threading.Thread(target=check) for i in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + assert res[0] + + +def test_thread_safe_GetGCPs(): + + with gdal.OpenEx( + "data/byte_gcp_pixelispoint.tif", gdal.OF_RASTER | gdal.OF_THREAD_SAFE + ) as ds: + + res = [True] + + def check(): + for i in range(100): + + if len(ds.GetGCPs()) != 4: + res[0] = False + assert False + + gcp_srs = ds.GetGCPSpatialRef() + if gcp_srs is None: + res[0] = False + assert False + if not gcp_srs.IsGeographic(): + res[0] = False + assert False + if "unretrievable - using WGS84" not in gcp_srs.ExportToWkt(): + res[0] = False + assert False + + gcp_wkt = ds.GetGCPProjection() + if not gcp_wkt: + res[0] = False + assert False + if "unretrievable - using WGS84" not in gcp_wkt: + res[0] = False + assert False + + if ds.GetSpatialRef(): + res[0] = False + assert False + if ds.GetProjectionRef() != "": + res[0] = False + assert False + + threads = [threading.Thread(target=check) for i in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + assert res[0] diff --git a/gcore/gdalthreadsafedataset.cpp b/gcore/gdalthreadsafedataset.cpp index 3b2c8156f249..2ce0de26d31c 100644 --- a/gcore/gdalthreadsafedataset.cpp +++ b/gcore/gdalthreadsafedataset.cpp @@ -177,6 +177,71 @@ class GDALThreadSafeDataset final : public GDALProxyDataset static GDALDataset *Create(GDALDataset *poPrototypeDS, int nScopeFlags); + /* All below public methods override GDALDataset methods, and instead of + * forwarding to a thread-local dataset, they act on the prototype dataset, + * because they return a non-trivial type, that could be invalidated + * otherwise if the thread-local dataset is evicted from the LRU cache. + */ + const OGRSpatialReference *GetSpatialRef() const override + { + std::lock_guard oGuard(m_oPrototypeDSMutex); + if (m_oSRS.IsEmpty()) + { + auto poSRS = m_poPrototypeDS->GetSpatialRef(); + if (poSRS) + { + m_oSRS.AssignAndSetThreadSafe(*poSRS); + } + } + return m_oSRS.IsEmpty() ? nullptr : &m_oSRS; + } + + const OGRSpatialReference *GetGCPSpatialRef() const override + { + std::lock_guard oGuard(m_oPrototypeDSMutex); + if (m_oGCPSRS.IsEmpty()) + { + auto poSRS = m_poPrototypeDS->GetGCPSpatialRef(); + if (poSRS) + { + m_oGCPSRS.AssignAndSetThreadSafe(*poSRS); + } + } + return m_oGCPSRS.IsEmpty() ? nullptr : &m_oGCPSRS; + } + + const GDAL_GCP *GetGCPs() override + { + std::lock_guard oGuard(m_oPrototypeDSMutex); + return const_cast(m_poPrototypeDS)->GetGCPs(); + } + + const char *GetMetadataItem(const char *pszName, + const char *pszDomain = "") override + { + std::lock_guard oGuard(m_oPrototypeDSMutex); + return const_cast(m_poPrototypeDS) + ->GetMetadataItem(pszName, pszDomain); + } + + char **GetMetadata(const char *pszDomain = "") override + { + std::lock_guard oGuard(m_oPrototypeDSMutex); + return const_cast(m_poPrototypeDS) + ->GetMetadata(pszDomain); + } + + /* End of methods that forward on the prototype dataset */ + + GDALAsyncReader *BeginAsyncReader(int, int, int, int, void *, int, int, + GDALDataType, int, int *, int, int, int, + char **) override + { + CPLError(CE_Failure, CPLE_AppDefined, + "GDALThreadSafeDataset::BeginAsyncReader() not supported"); + return nullptr; + } + protected: GDALDataset *RefUnderlyingDataset() const override; @@ -190,7 +255,7 @@ class GDALThreadSafeDataset final : public GDALProxyDataset friend class GDALThreadLocalDatasetCache; /** Mutex that protects accesses to m_poPrototypeDS */ - std::mutex m_oPrototypeDSMutex{}; + mutable std::mutex m_oPrototypeDSMutex{}; /** "Prototype" dataset, that is the dataset that was passed to the * GDALThreadSafeDataset constructor. All calls on to it should be on @@ -208,6 +273,12 @@ class GDALThreadSafeDataset final : public GDALProxyDataset */ const CPLStringList m_aosThreadLocalConfigOptions{}; + /** Cached value returned by GetSpatialRef() */ + mutable OGRSpatialReference m_oSRS{}; + + /** Cached value returned by GetGCPSpatialRef() */ + mutable OGRSpatialReference m_oGCPSRS{}; + /** Structure that references all GDALThreadLocalDatasetCache* instances. */ struct GlobalCache @@ -275,6 +346,48 @@ class GDALThreadSafeRasterBand final : public GDALProxyRasterBand GDALRasterAttributeTable *GetDefaultRAT() override; + /* All below public methods override GDALRasterBand methods, and instead of + * forwarding to a thread-local dataset, they act on the prototype band, + * because they return a non-trivial type, that could be invalidated + * otherwise if the thread-local dataset is evicted from the LRU cache. + */ + const char *GetMetadataItem(const char *pszName, + const char *pszDomain = "") override + { + std::lock_guard oGuard(m_poTSDS->m_oPrototypeDSMutex); + return const_cast(m_poPrototypeBand) + ->GetMetadataItem(pszName, pszDomain); + } + + char **GetMetadata(const char *pszDomain = "") override + { + std::lock_guard oGuard(m_poTSDS->m_oPrototypeDSMutex); + return const_cast(m_poPrototypeBand) + ->GetMetadata(pszDomain); + } + + const char *GetUnitType() override + { + std::lock_guard oGuard(m_poTSDS->m_oPrototypeDSMutex); + return const_cast(m_poPrototypeBand)->GetUnitType(); + } + + GDALColorTable *GetColorTable() override + { + std::lock_guard oGuard(m_poTSDS->m_oPrototypeDSMutex); + return const_cast(m_poPrototypeBand)->GetColorTable(); + } + + /* End of methods that forward on the prototype band */ + + CPLVirtualMem *GetVirtualMemAuto(GDALRWFlag, int *, GIntBig *, + char **) override + { + CPLError(CE_Failure, CPLE_AppDefined, + "GDALThreadSafeRasterBand::GetVirtualMemAuto() not supported"); + return nullptr; + } + protected: GDALRasterBand *RefUnderlyingRasterBand(bool bForceOpen) const override; void UnrefUnderlyingRasterBand(