forked from stevenpawley/r.learn.ml2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
135 lines (98 loc) · 3.22 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from grass.pygrass.utils import set_path
set_path('r.learn.ml')
from raster import RasterStack
# Raster map names used to build the stack. The original script referenced
# `maplist` without ever defining it (NameError); it is bound here to the
# same names the first RasterStack call used.
maplist = [
    "lsat5_1987_10",
    "lsat5_1987_20",
    "lsat5_1987_30",
    "lsat5_1987_40",
    "lsat5_1987_50",
    "lsat5_1987_70",
]

# Build the stack and look a layer up by its raster name (the bare attribute
# access is only useful for inspection in an interactive session).
stack = RasterStack(rasters=maplist)
stack.lsat5_1987_10

# Rebuild the stack from names with any '@...' suffix stripped. The
# comprehension already produces a new list, so the un-imported deepcopy()
# call the original made is not needed.
maplist2 = [name.split('@')[0] for name in maplist]
stack = RasterStack(rasters=maplist2)
stack.lsat5_1987_10
# Exercise the extraction API against the 'landclass96_roi' map. Each call
# rebinds X/y/crd or df, so only the last result of each kind survives.
# NOTE(review): this call passes `fields=` (a list) while the next one passes
# `field=` (a string) - confirm which keyword extract_points accepts.
X, y, crd = stack.extract_points(vect_name='landclass96_roi', fields=['value', 'cat'])
# Same method with as_df=True; the result is bound to a single name.
df = stack.extract_points(vect_name='landclass96_roi', field='value', as_df=True)
# extract_pixels takes the map through `response=` instead of `vect_name=`.
df = stack.extract_pixels(response='landclass96_roi', as_df=True)
# These X and y are the ones later passed to clf.fit() and cross_validate().
X, y, crd = stack.extract_pixels(response='landclass96_roi')
# Bare calls: results are only displayed when run interactively.
stack.head()
stack.tail()
# Import only the plotnine names used below instead of `from plotnine import *`,
# so the script's namespace is not flooded with unrelated names.
from plotnine import (
    aes,
    coord_fixed,
    element_blank,
    facet_wrap,
    geom_tile,
    ggplot,
    theme,
    theme_light,
)

# Read the whole stack and inspect the shape of the result (interactive check).
data = stack.read()
data.shape

# Convert the stack to a DataFrame, then reshape to long form: 'x' and 'y'
# stay as identifier columns and every other column is folded into the
# 'variable' / 'value' pair that the plot below maps to facets and fill.
df = stack.to_pandas()
# Alternative call kept from the original: stack.to_pandas(res=500)
df = df.melt(id_vars=['x', 'y'])

# One tile plot per stacked variable; the bare expression only renders when
# evaluated interactively.
(ggplot(df, aes(x="x", y="y", fill="value")) +
 geom_tile() +
 coord_fixed() +
 facet_wrap("variable") +
 theme_light() +
 theme(axis_title=element_blank()))
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

from grass.pygrass.modules.shortcuts import raster as r
# RasterRow was used here before the file imported it further down
# (NameError when run top to bottom), so import it where it is first needed.
from grass.pygrass.raster import RasterRow

# Fit a random forest on the X / y extracted earlier in the script.
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X, y)

# Both calls write to the output name 'test' with overwrite enabled.
stack.predict(clf, output='test', overwrite=True, height=25)
stack.predict_proba(clf, output='test', overwrite=True, height=25)

# Apply a random colour table to the result and display the raster object
# (the bare `test` expression only prints in an interactive session).
test = RasterRow('test')
r.colors('test', color='random')
test
test.close()

# 3-fold cross-validation of the same classifier on the same data.
cross_validate(clf, X, y, cv=3)
from grass.pygrass.gis.region import Region
from grass.pygrass.modules.grid.grid import GridModule
from grass.pygrass.modules.grid import split
from grass.pygrass.modules.shortcuts import general as g
from grass.pygrass.raster import RasterRow
import multiprocessing as mltp
from itertools import chain
import time
import numpy as np
# Region object; its `cols` and `rows` attributes size the tiles and the
# per-row loops used by the profiling code below.
reg = Region()
# profile reading region-based blocks
# Disabled alternative that was tried with GridModule:
# testreg = GridModule('g.region', width=100, height=100, processes=4)
# Split into tiles that are as wide as the region (reg.cols) with height=100.
testreg = split.split_region_tiles(width=reg.cols, height=100)
def worker(src):
    """Set the region to one tile and read a raster map as an array.

    Parameters
    ----------
    src : sequence
        Two items: the tile bounds as an iterable of ``(key, value)`` pairs
        containing the keys 'north', 'south', 'east', 'west', 'top' and
        'bottom', followed by the name of the raster map to read.

    Returns
    -------
    numpy.ndarray
        ``np.asarray`` of the opened raster.

    Notes
    -----
    NOTE(review): every worker calls ``g.region``; if that changes a region
    shared between processes, concurrent workers may interfere with each
    other - confirm before relying on the results or timings.
    """
    bounds, raster_name = src
    region_args = dict(bounds)

    # g.region is called with the short option names, so move each long key
    # to the end of the dict under its one-letter name.
    renames = (('north', 'n'), ('south', 's'), ('east', 'e'), ('west', 'w'))
    for long_key, short_key in renames:
        region_args[short_key] = region_args.pop(long_key)

    # The vertical bounds are not passed on to g.region.
    for unused_key in ('top', 'bottom'):
        del region_args[unused_key]

    g.region(**region_args)

    with RasterRow(raster_name) as raster:
        tile = np.asarray(raster)
    return tile
# Flatten the nested tile list and pair each tile's bounds (as key/value
# items) with the name of the raster to read.
windows = [
    [tile.items(), "lsat5_1987_10"]
    for tile in chain.from_iterable(testreg)
]

# Time pool start-up plus the parallel read of all tiles.
start = time.time()
pool = mltp.Pool(processes=8)
try:
    arrs = pool.map(func=worker, iterable=windows)
    end = time.time()
finally:
    # The original never released the pool, leaving its 8 worker processes
    # alive while the later benchmarks run. Shut them down after timing.
    pool.close()
    pool.join()
print(end - start)
# profile reading single thread per row
# Baseline for the pool-based runs: one process reads every row in turn.
start = time.time()
with RasterRow("lsat5_1987_10") as src:
    arr = [src[i] for i in range(reg.rows)]
end = time.time()
print(end - start)
# profile reading multiprocessing per row
def worker(src):
    """Read one row of a raster map.

    Parameters
    ----------
    src : tuple
        ``(row, name)``: the index of the row to read and the name of the
        raster map to read it from.

    Returns
    -------
    numpy.ndarray
        The data of the requested row.
    """
    row, name = src
    with RasterRow(name) as rs:
        # The original read the row into an unused local and returned the
        # row index instead, so the caller's `arrs` held indices rather than
        # data. Return the data, converted with np.asarray as the tile-based
        # worker does.
        return np.asarray(rs[row])
# One task per row: (row index, raster name).
rows = [(i, "lsat5_1987_10") for i in range(reg.rows)]

# Time pool start-up plus the parallel per-row read.
start = time.time()
pool = mltp.Pool(processes=8)
try:
    arrs = pool.map(func=worker, iterable=rows)
    end = time.time()
finally:
    # The original never released the pool; close and join it so its worker
    # processes do not outlive the benchmark.
    pool.close()
    pool.join()
print(end - start)
# Read the first row of the raster. Uses the same `with RasterRow(...)` form
# as the rest of the file so the map is closed even if the read raises; the
# bare `src[0]` expression only displays a value in an interactive session.
with RasterRow("lsat5_1987_10") as src:
    src[0]