Skip to content

Commit

Permalink
Added parallel extract regions / sample regions
Browse files Browse the repository at this point in the history
  • Loading branch information
Geethen committed Jul 12, 2022
1 parent 29e0943 commit 3ed1c9b
Showing 1 changed file with 84 additions and 2 deletions.
86 changes: 84 additions & 2 deletions geeml/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,13 @@ def downloadPoints(item):
executor.shutdown(wait=False, cancel_futures=True)
raise ex

def extractRegions(self, reduce = True, reducer = None, gridSize = 50000, batchSize = None, filename = 'output.csv'):
    """
    Extract summary statistics of covariates for regions.

    The area of interest is split into grid tiles. Tiles are processed in
    parallel by a thread pool; within a tile, features are requested in
    sequential batches and every batch is appended to a single csv file.

    Args:
        reduce (bool): default True. If False, each pixel within a polygon is downloaded
            (``sampleRegions``) instead of being summarised (``reduceRegions``).
        reducer (ee.Reducer): The reducer(s) to use to summarise data. If multiple reducers need
            to be applied, use combined reducers. Required when ``reduce`` is True.
        gridSize (int): The tile size used to filter features. Tiles run in parallel.
        batchSize (int): The maximum number of features requested at once within a tile.
            If a large gridSize results in Out of Memory errors, specify a batchSize smaller
            than the number of samples. Batches run in sequence. If None, all features of a
            tile are requested in a single batch.
        filename (str): Name of the csv file the rows are appended to. It is opened relative
            to the download directory (dd), which becomes the working directory.

    Returns:
        None. Data (csv) exported to download directory (dd).

    Raises:
        ValueError: If ``reduce`` is True and no ``reducer`` is given.
    """
    if reduce and reducer is None:
        raise ValueError('A reducer must be provided when reduce is True.')

    logger = logging.getLogger(__name__)

    max_threads = self.num_threads or min(32, (os.cpu_count() or 1) + 4)

    # Set working directory
    os.makedirs(self.dd, exist_ok=True)
    os.chdir(self.dd)

    self._properties = self.covariates.bandNames()
    self.properties = self._properties.getInfo()

    # add target band name to properties
    # NOTE(review): this compares the `name` attribute of the target with the string
    # 'ee.Image'; confirm this is the intended type check.
    if self.target.name == 'ee.Image':
        # Concatenate on the server-side list; the client-side list returned by
        # getInfo() has no .add() method.
        self.properties = self._properties.cat(self.target.bandNames()).getInfo()

    # Create grid
    grid, items = createGrid(gridSize, ee.Feature(self.aoi))

    self.batchSize = batchSize

    desc = 'Polygons'
    bar_format = ('{desc}: |{bar}| [{percentage:5.1f}%] in {elapsed:>5s} (eta: {remaining:>5s})')
    bar = tqdm(total = grid.size().getInfo(), desc=desc, bar_format=bar_format, dynamic_ncols=True, unit_scale=True, unit='B')

    warnings.filterwarnings('ignore', category=TqdmWarning)
    redir_tqdm = logging_redirect_tqdm([logging.getLogger(__package__)]) # redirect logging through tqdm

    # A single lock shared by all workers. Creating it inside the worker would give
    # every thread its own lock and therefore no mutual exclusion on the csv file.
    csv_writer_lock = threading.Lock()

    def downloadPolygons(item):
        """Download every feature that falls in one grid tile and append it to the csv."""
        polygons = self.geomPoints(grid, item)

        size = polygons.size().getInfo()
        if size == 0:
            return

        polygonsList = polygons.toList(size)

        # Without a batch size the whole tile is requested at once.
        step = batchSize or size

        # Stop at `size` (not size + 1) so no empty trailing batch is requested.
        for start in range(0, size, step):
            fc = ee.FeatureCollection(polygonsList.slice(start, start + step))

            if reduce:
                data = self.covariates.reduceRegions(fc, reducer, self.scale)
            else:
                data = self.covariates.sampleRegions(fc, scale = self.scale)

            # Keep the property names local to this worker so that concurrent
            # threads cannot swap them between building and running the request.
            propNames = data.first().propertyNames()
            header = propNames.getInfo()

            output = data.map(lambda ft: ft.set('output', propNames.map(lambda prop: ft.get(prop))))
            result = output.aggregate_array('output').getInfo()

            # Write the results to a file.
            with csv_writer_lock:
                self._properties = propNames
                self.properties = header

                # Checked under the lock so only the first writer emits the header.
                write_header = not os.path.isfile(filename)
                with open(filename, 'a', newline='') as f:
                    writer = csv.writer(f)
                    if write_header:
                        writer.writerow(header)
                    writer.writerows(result)

    with redir_tqdm, bar:
        with ThreadPoolExecutor(max_workers = max_threads) as executor:
            # Run the tile downloads in a thread pool
            futures = [executor.submit(downloadPolygons, tile) for tile in items]
            try:
                for future in as_completed(futures):
                    future.result()
                    bar.update(1)

            except Exception:
                logger.info('Cancelling...')
                executor.shutdown(wait=False, cancel_futures=True)
                raise

0 comments on commit 3ed1c9b

Please sign in to comment.