diff --git a/Environment Setup_Tutorial.pdf b/Environment Setup_Tutorial.pdf index 093e798..e5c19c3 100644 Binary files a/Environment Setup_Tutorial.pdf and b/Environment Setup_Tutorial.pdf differ diff --git a/README.md b/README.md index 51041b1..d83937e 100644 --- a/README.md +++ b/README.md @@ -1,141 +1,90 @@ -This code contains deep learning code used to modeling hydrologic systems, from soil moisture to streamflow, from projection to forecast. - -# Citations - -If you find our code to be useful, please cite the following papers: - -Feng, DP, K. Fang and CP. Shen, [Enhancing streamflow forecast and extracting insights using continental-scale long-short term memory networks with data integration], Water Resources Reserach (2020), https://doi.org/10.1029/2019WR026793 - -Fang, K., CP. Shen, D. Kifer and X. Yang, [Prolongation of SMAP to Spatio-temporally Seamless Coverage of Continental US Using a Deep Learning Neural Network], Geophysical Research Letters, doi: 10.1002/2017GL075619, preprint accessible at: arXiv:1707.06611 (2017) https://agupubs.onlinelibrary.wiley.com/doi/full/10.1002/2017GL075619 - -Shen, CP., [A trans-disciplinary review of deep learning research and its relevance for water resources scientists], Water Resources Research. 54(11), 8558-8593, doi: 10.1029/2018WR022643 (2018) https://doi.org/10.1029/2018WR022643 - -Major code contributor: Kuai Fang (PhD., Penn State), and smaller contribution from Dapeng Feng (PhD Student, Penn State) - -A new release is expected in early July, 2020, together with video code walkthrough. -Computational benchmark: training of CAMELS data (w/ or w/o data integration) with 671 basins, 10 years, 300 epochs, in ~1 hour with GPU. - -# Example -Two examples with sample data are wrapped up including - - [train a LSTM network to learn SMAP soil moisture](example/train-lstm.py) - - [estimate uncertainty of a LSTM network ](example/train-lstm-mca.py) - -A demo for temporal test is [here](example/demo-temporal-test.ipynb) - -# License -Non-Commercial Software License Agreement - -By downloading the hydroDL software (the “Software”) you agree to -the following terms of use: -Copyright (c) 2020, The Pennsylvania State University (“PSU”). All rights reserved. - -1. PSU hereby grants to you a perpetual, nonexclusive and worldwide right, privilege and -license to use, reproduce, modify, display, and create derivative works of Software for all -non-commercial purposes only. You may not use Software for commercial purposes without -prior written consent from PSU. Queries regarding commercial licensing should be directed -to The Office of Technology Management at 814.865.6277 or otminfo@psu.edu. -2. Neither the name of the copyright holder nor the names of its contributors may be used -to endorse or promote products derived from this software without specific prior written -permission. -3. This software is provided for non-commercial use only. -4. Redistribution and use in source and binary forms, with or without modification, are -permitted provided that redistributions must reproduce the above copyright notice, license, -list of conditions and the following disclaimer in the documentation and/or other materials -provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - - -# Database description -## Database Structure -``` -├── CONUS -│   ├── 2000 -│   │   ├── [Variable-Name].csv -│   │   ├── ... -│   │   ├── timeStr.csv -│   │   └── time.csv -│   ├── ... -│   ├── 2017 -│   │   └── ... -│   ├── const -│   │   ├── [Constant-Variable-Name].csv -│   │   └── ... -│   └── crd.csv -├── CONUSv4f1 -│   └── ... -├── Statistics -│   ├── [Variable-Name]_stat.csv -│   ├── ... -│   ├── const_[Constant-Variable-Name]_stat.csv -│   └── ... -├── Subset -│   ├── CONUS.csv -│   └── CONUSv4f1.csv -└── Variable - ├── varConstLst.csv - └── varLst.csv -``` -### 1. Dataset folders (*CONUS* , *CONUSv4f1*) -Data folder contains all data including both training and testing, time-dependent variables and constant variables. -In example data structure, there are two dataset folders - *CONUS* and *CONUSv4f1*. Those data are saved in: - - - **year/[Variable-Name].csv**: - -A csv file of size [#grid, #time], where each column is one grid and each row is one time step. This file saved data of a time-dependent variable of current year. For example, *CONUS/2010/SMAP_AM.csv* is SMAP data of 2002 on the CONUS. - -Most time-dependent varibles comes from NLDAS, which included two forcing product (FORA, FORB) and three simulations product land surface models (NOAH, MOS, VIC). Variables are named as *[variable]\_[product]\_[layer]*, and reference of variable can be found in [NLDAS document](https://hydro1.gesdisc.eosdis.nasa.gov/data/NLDAS/README.NLDAS2.pdf). For example, *SOILM_NOAH_0-10* refers to soil moisture product simulated by NOAH model at 0-10 cm. - -Other than NLDAS, SMAP data are also saved in same format but always used as target. In level 3 database, there are two SMAP csv files which are only available after 2015: *SMAP_AM.csv* and *SMAP_PM.csv*. - --9999 refers to NaN. - -- **year/time.csv** & **timeStr.csv** - -Dates csv file of current year folder, of size [#date]. *time.csv* recorded Matlab datenum and *timeStr.csv* recorded date in format of yyyy-mm-dd. - -Notice that each year start from and end before April 1st. For example data in folder 2010 is actually data from 2010-04-01 to 2011-03-31. The reason is that SMAP launched at April 1st. - -- **const/[Constant Variable Name].csv** - -csv file for constant variables of size [#grid]. - -- **crd.csv** - -Coordinate of all grids. First Column is latitude and second column is longitude. Each row refers a grid. - -### 2. Statistics folder - -Stored statistics of variables in order to do data normalization during training. Named as: -- Time dependent variables-> [variable name].csv -- Constant variables-> const_[variable name].csv - -Each file wrote four statistics of variable: -- 90 percentile -- 10 percentile -- mean -- std - -During training we normalize data by (data - mean) / std - -### 3. Subset folder -Subset refers to a subset of grids from the complete dataset (CONUS or Global). For example, a subset only contains grids in Pennsylvania. 
All subsets (including the CONUS or Global dataset) will have a *[subset name].csv* file in the *Subset* folder. *[subset name].csv* is wrote as:
-- line 1 -> root dataset
-- line 2 - end -> indexs of subset grids in rootset (start from 1)
-
-If the index is -1 means all grid, from example CONUS dataset.
-
-### 4. Variable folder
-Stored csv files contains a list of variables. Used as input to training code. Time-dependent variables and constant variables should be stored seperately. For example:
-- varLst.csv -> a list of time-dependent variables used as training predictors.
-- varLst.csv -> a list of constant variables used as training predictors.
+This repository contains deep learning code used to model hydrologic systems, from soil moisture to streamflow, from projection to forecast.
+
+This released code depends on our hydroDL repository. Please follow our original GitHub repository, where we will occasionally release new updates:
+https://github.com/mhpi/hydroDL
+# Citations
+
+If you find our code to be useful, please cite the following papers:
+
+Feng, DP., Lawson, K., and CP. Shen, Mitigating prediction error of deep learning streamflow models in large data-sparse regions with ensemble modeling and soft data, Geophysical Research Letters (2021, Accepted), arXiv preprint https://arxiv.org/abs/2011.13380
+
+Feng, DP, K. Fang and CP. Shen, Enhancing streamflow forecast and extracting insights using continental-scale long-short term memory networks with data integration, Water Resources Research (2020), https://doi.org/10.1029/2019WR026793
+
+Shen, CP., A trans-disciplinary review of deep learning research and its relevance for water resources scientists, Water Resources Research. 54(11), 8558-8593, doi: 10.1029/2018WR022643 (2018) https://doi.org/10.1029/2018WR022643
+
+Major code contributors: Dapeng Feng (PhD Student, Penn State) and Kuai Fang (PhD, Penn State)
+
+# Examples
+The environment we use is specified in the file `repoenv.yml`. To create the same conda environment, please run:
+```Shell
+conda env create -f repoenv.yml
+```
+Activate the installed environment before running the code:
+```Shell
+conda activate mhpihydrodl
+```
+You can also use the `Environment Setup_Tutorial.pdf` document as a reference to set up your environment and resolve some frequently encountered issues.
+There may be small compatibility issues with our code when using a very recent PyTorch version. You are welcome to contact us if you find any bugs or issues you cannot resolve.
+
+
+Several examples related to the above papers are presented here. **Click the title link** to see each example.
+## [1. Train an LSTM data integration model to make streamflow forecasts](example/StreamflowExample-DI.py)
+The dataset used is the NCAR CAMELS dataset. Download CAMELS following [this link](https://ral.ucar.edu/solutions/products/camels).
+Please download both the forcing/observation data `CAMELS time series meteorology, observed flow, meta data (.zip)` and the basin attributes `CAMELS Attributes (.zip)`.
+Put the two unzipped folders under the same directory, like `your/path/to/Camels/basin_timeseries_v1p2_metForcing_obsFlow` and `your/path/to/Camels/camels_attributes_v2.0`. Set the directory path `your/path/to/Camels`
+as the variable `rootDatabase` inside the code later, as in the sketch below.
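+For reference, here is a minimal sketch of this setup step, using the same calls the example scripts make (the path is a placeholder; replace it with your own CAMELS root):
+```python
+import sys
+sys.path.append('../')  # the example scripts add the repo root so hydroDL is importable
+import os
+from hydroDL.data import camels
+
+# directory that contains basin_timeseries_v1p2_metForcing_obsFlow and camels_attributes_v2.0
+rootDatabase = os.path.join(os.path.sep, 'scratch', 'Camels')  # e.g. /scratch/Camels
+camels.initcamels(rootDatabase)  # initialize camels module-scope variables: dirDB, gageDict, statDict
+```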
+Computational benchmark: training on the CAMELS data (w/ or w/o data integration) with 671 basins, 10 years, and 300 epochs takes ~1 hour with a GPU.
+
+Related papers:
+Feng et al. (2020). [Enhancing streamflow forecast and extracting insights using long‐short term memory networks with data integration at continental scales](https://doi.org/10.1029/2019WR026793). Water Resources Research.
+
+## [2. Train LSTM and CNN-LSTM models for prediction in ungauged regions](example/PUR/trainPUR-Reg.py)
+The dataset used is also NCAR CAMELS. Follow the instructions in the first example above to download and unzip the dataset. Use [this code](example/PUR/testPUR-Reg.py) to test your saved models after training is finished.
+
+Related papers:
+Feng et al. (2021, Accepted). Mitigating prediction error of deep learning streamflow models in large data-sparse regions with ensemble modeling and soft data. Geophysical Research Letters.
+Feng et al. (2020). [Enhancing streamflow forecast and extracting insights using long‐short term memory networks with data integration at continental scales](https://doi.org/10.1029/2019WR026793). Water Resources Research.
+
+## [3. Train an LSTM model to learn SMAP soil moisture](example/demo-LSTM-Tutorial.ipynb)
+The example dataset is embedded in this repo and can be found [here](example/data).
+You can also use [this script](example/train-lstm.py) to train a model if you don't want to work with Jupyter Notebook.
+
+Related papers:
+Fang et al. (2017), [Prolongation of SMAP to Spatio-temporally Seamless Coverage of Continental US Using a Deep Learning Neural Network](https://agupubs.onlinelibrary.wiley.com/doi/full/10.1002/2017GL075619), Geophysical Research Letters.
+
+## [4. Estimate uncertainty of an LSTM network](example/train-lstm-mca.py)
+Related papers:
+Fang et al. (2020). [Evaluating the potential and challenges of an uncertainty quantification method for long short-term memory models for soil moisture predictions](https://agupubs.onlinelibrary.wiley.com/doi/10.1029/2020WR028095), Water Resources Research.
+# License
+Non-Commercial Software License Agreement
+
+By downloading the hydroDL software (the “Software”) you agree to
+the following terms of use:
+Copyright (c) 2020, The Pennsylvania State University (“PSU”). All rights reserved.
+
+1. PSU hereby grants to you a perpetual, nonexclusive and worldwide right, privilege and
+license to use, reproduce, modify, display, and create derivative works of Software for all
+non-commercial purposes only. You may not use Software for commercial purposes without
+prior written consent from PSU. Queries regarding commercial licensing should be directed
+to The Office of Technology Management at 814.865.6277 or otminfo@psu.edu.
+2. Neither the name of the copyright holder nor the names of its contributors may be used
+to endorse or promote products derived from this software without specific prior written
+permission.
+3. This software is provided for non-commercial use only.
+4. Redistribution and use in source and binary forms, with or without modification, are
+permitted provided that redistributions must reproduce the above copyright notice, license,
+list of conditions and the following disclaimer in the documentation and/or other materials
+provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/example/.ipynb_checkpoints/demo-temporal-test-checkpoint.ipynb b/example/.ipynb_checkpoints/demo-temporal-test-checkpoint.ipynb index 162a0f0..22c4bd5 100644 --- a/example/.ipynb_checkpoints/demo-temporal-test-checkpoint.ipynb +++ b/example/.ipynb_checkpoints/demo-temporal-test-checkpoint.ipynb @@ -26,7 +26,7 @@ "text": [ "loading package hydroDL\n", "/home/kxf227/work/GitHUB/pyRnnSMAP/example/data/Subset/CONUSv4f1.csv\n", - "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/SMAP_AM.csv 0.04537510871887207\n" + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/SMAP_AM.csv 0.043489694595336914\n" ] } ], @@ -62,13 +62,13 @@ "output_type": "stream", "text": [ "/home/kxf227/work/GitHUB/pyRnnSMAP/example/data/Subset/CONUSv4f1.csv\n", - "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/APCP_FORA.csv 0.044591665267944336\n", - "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/DLWRF_FORA.csv 0.052686452865600586\n", - "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/DSWRF_FORA.csv 0.050998687744140625\n", - "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/TMP_2_FORA.csv 0.051717281341552734\n", - "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/SPFH_2_FORA.csv 0.05404353141784668\n", - "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/VGRD_10_FORA.csv 0.051822662353515625\n", - "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/UGRD_10_FORA.csv 0.0521092414855957\n" + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/APCP_FORA.csv 0.04564547538757324\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/DLWRF_FORA.csv 0.05148744583129883\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/DSWRF_FORA.csv 0.051079750061035156\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/TMP_2_FORA.csv 0.05147600173950195\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/SPFH_2_FORA.csv 0.05448579788208008\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/VGRD_10_FORA.csv 0.05189824104309082\n", + "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/UGRD_10_FORA.csv 0.05310535430908203\n" ] } ], @@ -874,7 +874,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" diff --git a/example/PUR/testPUR-Reg.py b/example/PUR/testPUR-Reg.py new file mode 100644 index 0000000..7b525b5 --- /dev/null +++ b/example/PUR/testPUR-Reg.py @@ -0,0 +1,295 @@ +import sys +sys.path.append('../') +from hydroDL import master +from hydroDL.post import plot, stat +import matplotlib.pyplot as plt +from hydroDL.data import camels +from hydroDL.master import loadModel +from hydroDL.model import train + +import numpy as np +import pandas as pd +import torch +import json +import os +import random + + +interfaceOpt = 1 +# set it the same as which you used to train your model +# ==1 default, 
the recommended and more interpretable version
+# ==0 the "pro" version
+
+# Define root directory of database and output
+# Modify these based on your own locations of the CAMELS dataset and saved models
+
+rootDatabase = os.path.join(os.path.sep, 'scratch', 'Camels') # CAMELS dataset root directory
+camels.initcamels(rootDatabase) # initialize three camels module-scope variables in camels.py: dirDB, gageDict, statDict
+
+rootOut = os.path.join(os.path.sep, 'data', 'rnnStreamflow') # Model output root directory
+
+# The directory you defined in training to save the models under the above rootOut
+exp_name='PUR'
+exp_disp='Testrun'
+save_path = os.path.join(rootOut, exp_name, exp_disp)
+
+random.seed(159654)
+# this random seed is only used for the fractional FDC scenarios,
+# to sample which basins in the target region have FDCs.
+
+# same as training: get the basin IDs of the 7 regions for testing
+gageinfo = camels.gageDict
+hucinfo = gageinfo['huc']
+gageid = gageinfo['id']
+# get the id list of each region
+regionID = list()
+regionNum = list()
+regionDivide = [ [1,2], [3,6], [4,5,7], [9,10], [8,11,12,13], [14,15,16,18], [17] ] # seven regions
+for ii in range(len(regionDivide)):
+    tempcomb = regionDivide[ii]
+    tempregid = list()
+    for ih in tempcomb:
+        tempid = gageid[hucinfo==ih].tolist()
+        tempregid = tempregid + tempid
+    regionID.append(tempregid)
+    regionNum.append(len(tempregid))
+
+
+# test trained models
+testgpuid = 0 # which GPU to use for testing
+torch.cuda.set_device(testgpuid)
+
+# # test models run by different random seeds.
+# # final reported results are the mean of predictions from different seeds
+# seedid = [159654, 109958, 257886, 142365, 229837, 588859]
+seedid = [159654] # take this seed as an example
+
+gageinfo = camels.gageDict
+gageid = gageinfo['id']
+subName = 'Sub'
+tRange = [19951001, 20051001] # testing period
+testEpoch = 300 # test the model saved after this many training epochs
+FDCMig = True # option for fractional FDC experiments, migrating 1/3 or 1/10 of the FDCs to all basins in the target region
+FDCrange = [19851001, 19951001]
+FDCstr = str(FDCrange[0])+'-'+str(FDCrange[1])
+
+expName = 'Full' # testing the full-attribute model as an example. Could be 'Full', 'Noattr', '5attr'
+caseLst = ['-'.join(['Reg-85-95', subName, expName])] # PUR: 'Reg-85-95-Sub-Full'
+caseLst.append('-'.join(['Reg-85-95', subName, expName, 'FDC']) + FDCstr) # PUR with FDC: 'Reg-85-95-Sub-Full-FDC' + FDCstr
+
+# samFrac = [1/3, 1/10] # fractional FDCs available in the target region
+samFrac = [1/3]
+migOptLst = [False, False] # flags indicating whether each entry in caseLst is a migration experiment
+if FDCMig == True:
+    for im in range(len(samFrac)):
+        caseLst.append('-'.join(['Reg-85-95', subName, expName, 'FDC']) + FDCstr)
+        migOptLst.append(True)
+# caseLst summarizes all the experiment directories needed for testing
+
+# Get the randomly sampled basin IDs in the target region that have FDCs and save them to file for future use
+# the sampled basins should be the same for different ensemble members
+# this part only needs to be run once to generate the sample file.
The following section will read the saved file +sampleIndLst = list() +for ir in range(len(regionID)): + testBasin = regionID[ir] + sampNum = round(1/3 * len(testBasin)) # or 1/10 + sampleIn = random.sample(range(0, len(testBasin)), sampNum) + sampleIndLst.append(sampleIn) +samLstFile = os.path.join(save_path, 'samp103Lst.json') # or 'samp110Lst.json' +with open(samLstFile, 'w') as fp: + json.dump(sampleIndLst, fp, indent=4) + + +# Load the sample Ind +indfileLst = ['samp103Lst.json'] # ['samp103Lst.json', 'samp110Lst.json'] +sampleInLstAll = list() +for ii in range(len(indfileLst)): + samLstFile = os.path.join(save_path, indfileLst[ii]) + with open(samLstFile, 'r') as fp: + tempind = json.load(fp) + sampleInLstAll.append(tempind) + +for iEns in range(len(seedid)): #test trained models with different seeds + tempseed = seedid[iEns] + predtempLst = [] + regcount = 0 + + # for iT in range(len(regionID)): # test all the 7 regions + for iT in range(0, 1): # take region 1 as an example + testBasin = regionID[iT] # testing basins + testInd = [gageid.tolist().index(x) for x in testBasin] + trainBasin = list(set(gageid.tolist()) - set(testBasin)) + trainInd = [gageid.tolist().index(x) for x in trainBasin] + testregdic = 'Reg-'+str(iT+1)+'-Num'+str(regionNum[iT]) + + # Migrate FDC for fractional experiment based on the nearest distance + if FDCMig == True: + FDCList = [] + testlat = gageinfo['lat'][testInd] + testlon = gageinfo['lon'][testInd] + for iF in range(len(samFrac)): + sampleInLst = sampleInLstAll[iF] + samplelat = testlat[sampleInLst[iT]] + samplelon = testlon[sampleInLst[iT]] + nearID = list() + # calculate distances to the gages with FDC available + # and identify using the FDC of which gage for each test basin + for ii in range(len(testlat)): + dist = np.sqrt((samplelat-testlat[ii])**2 + (samplelon-testlon[ii])**2) + nearID.append(np.argmin(dist)) + FDCLS = gageid[testInd][sampleInLst[iT]][nearID].tolist() + FDCList.append(FDCLS) + + outLst = [os.path.join(save_path, str(tempseed), testregdic, x) for x in caseLst] + # all the directories to test in this list + + icount = 0 + imig = 0 + for out in outLst: + # testing sequence: LSTM, LSTM with FDC, LSTM with fractional FDC migration + if interfaceOpt == 1: + # load testing data + mDict = master.readMasterFile(out) + optData = mDict['data'] + df = camels.DataframeCamels( + subset=testBasin, tRange=tRange) + x = df.getDataTs( + varLst=optData['varT'], + doNorm=False, + rmNan=False) + obs = df.getDataObs( + doNorm=False, + rmNan=False, + basinnorm=False) + c = df.getDataConst( + varLst=optData['varC'], + doNorm=False, + rmNan=False) + + # do normalization and remove nan + # load the saved statDict + statFile = os.path.join(out, 'statDict.json') + with open(statFile, 'r') as fp: + statDict = json.load(fp) + seriesvarLst = optData['varT'] + climateList = optData['varC'] + attr_norm = camels.transNormbyDic(c, climateList, statDict, toNorm=True) + attr_norm[np.isnan(attr_norm)] = 0.0 + xTest = camels.transNormbyDic(x, seriesvarLst, statDict, toNorm=True) + xTest[np.isnan(xTest)] = 0.0 + if attr_norm.size == 0: # [], no-attribute case + attrs = None + else: + attrs = attr_norm + + + if optData['lckernel'] is not None: + if migOptLst[icount] is True: + # the case migrating FDCs + dffdc = camels.DataframeCamels(subset=FDCList[imig], tRange=optData['lckernel']) + imig = imig+1 + else: + dffdc = camels.DataframeCamels(subset=testBasin, tRange=optData['lckernel']) + datatemp = dffdc.getDataObs( + doNorm=False, rmNan=False, basinnorm=True) + # 
normalize data + dadata = camels.transNormbyDic(datatemp, 'runoff', statDict, toNorm=True) + dadata = np.squeeze(dadata) # dim Ngrid*Nday + fdcdata = master.master.calFDC(dadata) + print('FDC was calculated and used!') + xIn = (xTest, fdcdata) + else: + xIn = xTest + + # load and forward the model for testing + testmodel = loadModel(out, epoch=testEpoch) + filePathLst = master.master.namePred( + out, tRange, 'All', epoch=testEpoch) # prepare the name of csv files to save testing results + train.testModel( + testmodel, xIn, c=attrs, filePathLst=filePathLst) + # read out predictions + dataPred = np.ndarray([obs.shape[0], obs.shape[1], len(filePathLst)]) + for k in range(len(filePathLst)): + filePath = filePathLst[k] + dataPred[:, :, k] = pd.read_csv( + filePath, dtype=np.float, header=None).values + # transform back to the original observation + temppred = camels.transNormbyDic(dataPred, 'runoff', statDict, toNorm=False) + pred = camels.basinNorm(temppred, np.array(testBasin), toNorm=False) + + elif interfaceOpt == 0: + if migOptLst[icount] is True: + # for FDC migration case + df, pred, obs = master.test(out, tRange=tRange, subset=testBasin, basinnorm=True, epoch=testEpoch, + reTest=True, FDCgage=FDCList[imig]) + imig = imig + 1 + else: + # for other ordinary cases + df, pred, obs = master.test(out, tRange=tRange, subset=testBasin, basinnorm=True, epoch=testEpoch, + reTest=True) + + ## change the units ft3/s to m3/s + obs = obs*0.0283168 + pred = pred*0.0283168 + + # concatenate results in different regions to one array + # and save the array of different experiments to a list + if regcount == 0: + predtempLst.append(pred) + else: + predtempLst[icount] = np.concatenate([predtempLst[icount], pred], axis=0) + icount = icount + 1 + if regcount == 0: + obsAll = obs + else: + obsAll = np.concatenate([obsAll, obs], axis=0) + regcount = regcount+1 + + # concatenate results of different seeds to the third dim of array + if iEns == 0: + predLst = predtempLst + else: + for ii in range(len(outLst)): + predLst[ii] = np.concatenate([predLst[ii], predtempLst[ii]], axis=2) + # predLst: List of all experiments with shape: Ntime*Nbasin*Nensemble + +# get the ensemble mean from simulations of different seeds +ensLst = [] +for ii in range(len(outLst)): + temp = np.nanmean(predLst[ii], axis=2, keepdims=True) + ensLst.append(temp) + + +# plot boxplots for different experiments +statDictLst = [stat.statError(x.squeeze(), obsAll.squeeze()) for x in ensLst] +keyLst=['NSE', 'KGE'] # which metric to show +dataBox = list() +for iS in range(len(keyLst)): + statStr = keyLst[iS] + temp = list() + for k in range(len(statDictLst)): + data = statDictLst[k][statStr] + data = data[~np.isnan(data)] + temp.append(data) + dataBox.append(temp) + +plt.rcParams['font.size'] = 14 +labelname = ['PUR', 'PUR-FDC', 'PUR-1/3FDC'] +xlabel = ['NSE', 'KGE'] +fig = plot.plotBoxFig(dataBox, xlabel, labelname, sharey=False, figsize=(6, 5)) +fig.patch.set_facecolor('white') +fig.show() + +# save evaluation results +outpath = os.path.join(save_path, 'TestResults', expName) +if not os.path.isdir(outpath): + os.makedirs(outpath) + +EnsEvaFile = os.path.join(outpath, 'EnsEva'+str(testEpoch)+'.npy') +np.save(EnsEvaFile, statDictLst) + +obsFile = os.path.join(outpath, 'obs.npy') +np.save(obsFile, obsAll) + +predFile = os.path.join(outpath, 'pred'+str(testEpoch)+'.npy') +np.save(predFile, predLst) diff --git a/example/PUR/trainPUR-Reg.py b/example/PUR/trainPUR-Reg.py new file mode 100644 index 0000000..7a2afea --- /dev/null +++ 
b/example/PUR/trainPUR-Reg.py @@ -0,0 +1,289 @@ +import sys +sys.path.append('../') +from hydroDL import master +from hydroDL.master import default +from hydroDL.data import camels +from hydroDL.model import rnn, crit, train + + +import json +import os +import numpy as np +import torch +import random + + +# Options for different interface +interfaceOpt = 1 +# ==1 default, the improved and more interpretable version. It's easier to see the data flow, model setup and training +# process. Recommended for most users. +# ==0 the original "pro" version we used to run heavy jobs for the paper. It was later improved for clarity to obtain option 1. +# Results are very similar for two options and have little difference in computational performance. + +Action = [1, 2] +# Using Action options to control training different models +# 1: Train Base LSTM PUR Models without integrating any soft info +# 2: Train CNN-LSTM to integrate FDCs + +# Hyperparameters +EPOCH = 300 +BATCH_SIZE=100 +RHO=365 +HIDDENSIZE=256 +saveEPOCH = 10 # save model for every "saveEPOCH" epochs +Ttrain=[19851001, 19951001] # training period +LCrange = [19851001, 19951001] + +# Define root directory of database and output +# Modify this based on your own location of CAMELS dataset +# Following the data download instruction in README file, you should organize the folders like +# 'your/path/to/Camels/basin_timeseries_v1p2_metForcing_obsFlow' and 'your/path/to/Camels/camels_attributes_v2.0' +# Then 'rootDatabase' here should be 'your/path/to/Camels' +# You can also define the database directory in hydroDL/__init__.py by modifying pathCamels['DB'] variable + +rootDatabase = os.path.join(os.path.sep, 'scratch', 'Camels') # CAMELS dataset root directory +camels.initcamels(rootDatabase) # initialize three camels module-scope variables in camels.py: dirDB, gageDict, statDict + +rootOut = os.path.join(os.path.sep, 'data', 'rnnStreamflow') # Model output root directory + +# define random seed +# seedid = [159654, 109958, 257886, 142365, 229837, 588859] # six seeds randomly generated using np.random.uniform +seedid = 159654 +random.seed(seedid) +torch.manual_seed(seedid) +np.random.seed(seedid) +torch.cuda.manual_seed(seedid) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +# Fix seed for training, change it to have different runnings with different seeds +# We use the mean discharge of 6 runnings with different seeds to account for randomness + + +# directory to save training results +exp_name='PUR' +exp_disp='Testrun' +save_path = os.path.join(exp_name, exp_disp, str(seedid)) + +# Divide CAMELS dataset into 7 PUR regions +gageinfo = camels.gageDict +hucinfo = gageinfo['huc'] +gageid = gageinfo['id'] +# get the id list of each region +regionID = list() +regionNum = list() +regionDivide = [ [1,2], [3,6], [4,5,7], [9,10], [8,11,12,13], [14,15,16,18], [17] ] # seven regions +for ii in range(len(regionDivide)): + tempcomb = regionDivide[ii] + tempregid = list() + for ih in tempcomb: + tempid = gageid[hucinfo==ih].tolist() + tempregid = tempregid + tempid + regionID.append(tempregid) + regionNum.append(len(tempregid)) + +# Only for interfaceOpt=0 using multiple GPUs, not used here +# cid = 0 # starting GPU id +# gnum = 6 # how many GPUs you have + +# Region withheld as testing target. Take region 1 as an example. +# Change this to 1,2,..,7 to run models for all 7 PUR regions in CONUS. 
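+# (Note: regionDivide above groups the HUC units into these seven PUR regions;
+# regionID[testRegion - 1] is then the list of basin IDs withheld from training and used as the testing target.)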
+testRegion = 1 + +iexp = testRegion - 1 # index +TestLS = regionID[iexp] # basin ID list for testing, should be withheld for training +TrainLS = list(set(gageid.tolist()) - set(TestLS)) # basin ID for training +gageDic = {'TrainID': TrainLS, 'TestID': TestLS} + +# prepare the training dataset +optData = default.optDataCamels +optData = default.update(optData, tRange=Ttrain, subset=TrainLS, lckernel=None, fdcopt=False) +climateList = camels.attrLstSel + ['p_mean','pet_mean','p_seasonality','frac_snow','aridity','high_prec_freq', + 'high_prec_dur','low_prec_freq','low_prec_dur'] +# climateList = ['slope_mean', 'area_gages2', 'frac_forest', 'soil_porosity', 'max_water_content'] +# climateList = [] +optData = default.update(optData, varT=camels.forcingLst, varC= climateList) +# varT: forcing used for training varC: attributes used for training + +# The above controls what attributes used for training, change varC for input-selection-ensemble +# for 5 attributes model: climateList = ['slope_mean', 'area_gages2', 'frac_forest', 'soil_porosity', 'max_water_content'] +# for no-attribute model: varC = [] +# the input-selection ensemble represents using the mean prediction of full, 5-attr and no-attr models, +# in total the mean of 3(different attributes)*6(different random seeds) = 18 models + +if interfaceOpt == 1: +# read data from CAMELS dataset + df = camels.DataframeCamels( + subset=optData['subset'], tRange=optData['tRange']) + x = df.getDataTs( + varLst=optData['varT'], + doNorm=False, + rmNan=False) + y = df.getDataObs( + doNorm=False, + rmNan=False, + basinnorm=True) + # "basinnorm = True" will call camels.basinNorm() on the original discharge data. This will transform discharge + # from ft3/s to mm/day and then divided by mean precip to be dimensionless. 
output = discharge/(area*mean_precip) + c = df.getDataConst( + varLst=optData['varC'], + doNorm=False, + rmNan=False) + + # process, do normalization and remove nan + series_data = np.concatenate([x, y], axis=2) + seriesvarLst = camels.forcingLst + ['runoff'] + # calculate statistics for normalization and save to a dictionary + statDict = camels.getStatDic(attrLst=climateList, attrdata=c, seriesLst=seriesvarLst, seriesdata=series_data) + # normalize + attr_norm = camels.transNormbyDic(c, climateList, statDict, toNorm=True) + attr_norm[np.isnan(attr_norm)] = 0.0 + series_norm = camels.transNormbyDic(series_data, seriesvarLst, statDict, toNorm=True) + + # prepare the inputs + xTrain = series_norm[:,:,:-1] # forcing, not include obs + xTrain[np.isnan(xTrain)] = 0.0 + yTrain = np.expand_dims(series_norm[:,:,-1], 2) + + if attr_norm.size == 0: # [], no-attribute case + attrs = None + Nx = xTrain.shape[-1] + else: + # with attributes + attrs=attr_norm + Nx = xTrain.shape[-1] + attrs.shape[-1] + Ny = yTrain.shape[-1] + +# define loss function +optLoss = default.optLossRMSE +lossFun = crit.RmseLoss() +# configuration for training +optTrain = default.update(default.optTrainCamels, miniBatch=[BATCH_SIZE, RHO], nEpoch=EPOCH, saveEpoch=saveEPOCH, seed=seedid) +hucdic = 'Reg-'+str(iexp+1)+'-Num'+str(regionNum[iexp]) + +if 1 in Action: +# Train base LSTM PUR model + out = os.path.join(rootOut, save_path, hucdic,'Reg-85-95-Sub-Full') + # out = os.path.join(rootOut, save_path, hucdic,'Reg-85-95-Sub-5attr') + # out = os.path.join(rootOut, save_path, hucdic,'Reg-85-95-Sub-Noattr') + if not os.path.isdir(out): + os.makedirs(out) + # log training gage information + gageFile = os.path.join(out, 'gage.json') + with open(gageFile, 'w') as fp: + json.dump(gageDic, fp, indent=4) + # define model config + optModel = default.update(default.optLstm, name='hydroDL.model.rnn.CudnnLstmModel', hiddenSize=HIDDENSIZE) + + if interfaceOpt == 1: + # define, load and train model + optModel = default.update(optModel, nx=Nx, ny=Ny) + model = rnn.CudnnLstmModel(nx=optModel['nx'], ny=optModel['ny'], hiddenSize=optModel['hiddenSize']) + # Wrap up all the training configurations to one dictionary in order to save into "out" folder + masterDict = master.wrapMaster(out, optData, optModel, optLoss, optTrain) + master.writeMasterFile(masterDict) + # log statistics + statFile = os.path.join(out, 'statDict.json') + with open(statFile, 'w') as fp: + json.dump(statDict, fp, indent=4) + # Train the model + trainedModel = train.trainModel( + model, + xTrain, + yTrain, + attrs, + lossFun, + nEpoch=EPOCH, + miniBatch=[BATCH_SIZE, RHO], + saveEpoch=saveEPOCH, + saveFolder=out) + + if interfaceOpt == 0: + # Only need to pass the wrapped configuration dict 'masterDict' for training + # nx, ny will be automatically updated later + masterDict = master.wrapMaster(out, optData, optModel, optLoss, optTrain) + master.train(masterDict) + + ## Not used here. + ## A potential way to run batch jobs simultaneously in background through multiple GPUs and Linux screens. + ## To use this, must manually set the "pathCamels['DB']" in hydroDL/__init__.py as your own root path of CAMELS data. + ## Use the following master.runTrain() instead of the above master.train(). 
+ # master.runTrain(masterDict, cudaID=cid % gnum, screen='test-'+str(cid)) + # cid = cid + 1 + +if 2 in Action: +# Train CNN-LSTM PUR model to integrate FDCs + # LCrange defines from which period to get synthetic FDC + LCTstr = str(LCrange[0]) + '-' + str(LCrange[1]) + out = os.path.join(rootOut, save_path, hucdic, 'Reg-85-95-Sub-Full-FDC' + LCTstr) + # out = os.path.join(rootOut, save_path, hucdic, 'Reg-85-95-Sub-5attr-FDC' + LCTstr) + # out = os.path.join(rootOut, save_path, hucdic, 'Reg-85-95-Sub-Noattr-FDC' + LCTstr) + if not os.path.isdir(out): + os.makedirs(out) + gageFile = os.path.join(out, 'gage.json') + with open(gageFile, 'w') as fp: + json.dump(gageDic, fp, indent=4) + + optData = default.update(default.optDataCamels, tRange=Ttrain, subset=TrainLS, + lckernel=LCrange, fdcopt=True) + # define model + convNKS = [(10, 5, 1), (5, 3, 3), (1, 1, 1)] + # CNN parameters for 3 layers: [(Number of kernels 10,5,1), (kernel size 5,3,3), (stride 1,1,1)] + optModel = default.update(default.optCnn1dLstm, name='hydroDL.model.rnn.CNN1dLCmodel', + hiddenSize=HIDDENSIZE, convNKS=convNKS, poolOpt=[2, 2, 1]) # use CNN-LSTM model + + if interfaceOpt == 1: + # load data and create synthetic FDCs as inputs + dffdc = camels.DataframeCamels(subset=optData['subset'], tRange=optData['lckernel']) + datatemp = dffdc.getDataObs( + doNorm=False, rmNan=False, basinnorm=True) + # normalize data + dadata = camels.transNormbyDic(datatemp, 'runoff', statDict, toNorm=True) + dadata = np.squeeze(dadata) # dim Nbasin*Nday + fdcdata = master.master.calFDC(dadata) + print('FDC was calculated and used!') + xIn = (xTrain, fdcdata) + + # load model + Nobs = xIn[1].shape[-1] + optModel = default.update(optModel, nx=Nx, ny=Ny, nobs=Nobs) # update input dims + convpara = optModel['convNKS'] + + model = rnn.CNN1dLCmodel( + nx=optModel['nx'], + ny=optModel['ny'], + nobs=optModel['nobs'], + hiddenSize=optModel['hiddenSize'], + nkernel=convpara[0], + kernelSize=convpara[1], + stride=convpara[2], + poolOpt=optModel['poolOpt']) + print('CNN1d Local calibartion Kernel is used!') + + # Wrap up all the training configurations to one dictionary in order to save into "out" folder + masterDict = master.wrapMaster(out, optData, optModel, optLoss, optTrain) + master.writeMasterFile(masterDict) + # log statistics + statFile = os.path.join(out, 'statDict.json') + with open(statFile, 'w') as fp: + json.dump(statDict, fp, indent=4) + + # Train the model + trainedModel = train.trainModel( + model, + xIn, # need to well defined + yTrain, + attrs, + lossFun, + nEpoch=EPOCH, + miniBatch=[BATCH_SIZE, RHO], + saveEpoch=saveEPOCH, + saveFolder=out) + + if interfaceOpt == 0: + # Only need to pass the wrapped configuration 'masterDict' for training + # nx, ny, nobs will be automatically updated later + masterDict = master.wrapMaster(out, optData, optModel, optLoss, optTrain) + master.train(masterDict) # train model + + # master.runTrain(masterDict, cudaID=cid % gnum, screen='test-'+str(cid)) + # cid = cid + 1 diff --git a/example/StreamflowExample-DI.py b/example/StreamflowExample-DI.py new file mode 100644 index 0000000..ecdf089 --- /dev/null +++ b/example/StreamflowExample-DI.py @@ -0,0 +1,399 @@ +import sys +sys.path.append('../') +from hydroDL import master, utils +from hydroDL.master import default, loadModel +from hydroDL.post import plot, stat +import matplotlib.pyplot as plt +from hydroDL.data import camels +from hydroDL.model import rnn, crit, train + +import numpy as np +import pandas as pd +import os +import torch +import random +import 
datetime as dt +import json + +# Options for different interface +interfaceOpt = 1 +# ==1 default, the recommended and more interpretable version with clear data and training flow. We improved the +# original one to explicitly load and process data, set up model and loss, and train the model. +# ==0, the original "pro" version to train jobs based on the defined configuration dictionary. +# Results are very similar for two options. + +# Options for training and testing +# 0: train base model without DI +# 1: train DI model +# 0,1: do both base and DI model +# 2: test trained models +Action = [0,1] +# gpuid = 0 +# torch.cuda.set_device(gpuid) + +# Set hyperparameters +EPOCH = 300 +BATCH_SIZE = 100 +RHO = 365 +HIDDENSIZE = 256 +saveEPOCH = 10 # save model for every "saveEPOCH" epochs +Ttrain = [19851001, 19951001] # Training period + +# Fix random seed +seedid = 111111 +random.seed(seedid) +torch.manual_seed(seedid) +np.random.seed(seedid) +torch.cuda.manual_seed(seedid) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +# Change the seed to have different runnings. +# We use the mean discharge of 6 runnings with different seeds to account for randomness and report results + +# Define root directory of database and output +# Modify this based on your own location of CAMELS dataset. +# Following the data download instruction in README file, you should organize the folders like +# 'your/path/to/Camels/basin_timeseries_v1p2_metForcing_obsFlow' and 'your/path/to/Camels/camels_attributes_v2.0' +# Then 'rootDatabase' here should be 'your/path/to/Camels' +# You can also define the database directory in hydroDL/__init__.py by modifying pathCamels['DB'] variable +rootDatabase = os.path.join(os.path.sep, 'scratch', 'Camels') # CAMELS dataset root directory: /scratch/Camels +camels.initcamels(rootDatabase) # initialize three camels module-scope variables in camels.py: dirDB, gageDict, statDict + +rootOut = os.path.join(os.path.sep, 'data', 'rnnStreamflow') # Root directory to save training results: /data/rnnStreamflow + +# Define all the configurations into dictionary variables +# three purposes using these dictionaries. 1. saved as configuration logging file. 2. for future testing. 3. can also +# be used to directly train the model when interfaceOpt == 0 + +# define dataset +# default module stores default configurations, using update to change the config +optData = default.optDataCamels +optData = default.update(optData, varT=camels.forcingLst, varC=camels.attrLstSel, tRange=Ttrain) # Update the training period + +if (interfaceOpt == 1) and (2 not in Action): + # load training data explicitly for the interpretable interface. Notice: if you want to apply our codes to your own + # dataset, here is the place you can replace data. + # read data from original CAMELS dataset + # df: CAMELS dataframe; x: forcings[nb,nt,nx]; y: streamflow obs[nb,nt,ny]; c:attributes[nb,nc] + # nb: number of basins, nt: number of time steps (in Ttrain), nx: number of time-dependent forcing variables + # ny: number of target variables, nc: number of constant attributes + df = camels.DataframeCamels( + subset=optData['subset'], tRange=optData['tRange']) + x = df.getDataTs( + varLst=optData['varT'], + doNorm=False, + rmNan=False) + y = df.getDataObs( + doNorm=False, + rmNan=False, + basinnorm=False) + # transform discharge from ft3/s to mm/day and then divided by mean precip to be dimensionless. 
+ # output = discharge/(area*mean_precip) + # this can also be done by setting the above option "basinnorm = True" for df.getDataObs() + y_temp = camels.basinNorm(y, optData['subset'], toNorm=True) + c = df.getDataConst( + varLst=optData['varC'], + doNorm=False, + rmNan=False) + + # process, do normalization and remove nan + series_data = np.concatenate([x, y_temp], axis=2) + seriesvarLst = camels.forcingLst + ['runoff'] + # calculate statistics for norm and saved to a dictionary + statDict = camels.getStatDic(attrLst=camels.attrLstSel, attrdata=c, seriesLst=seriesvarLst, seriesdata=series_data) + # normalize + attr_norm = camels.transNormbyDic(c, camels.attrLstSel, statDict, toNorm=True) + attr_norm[np.isnan(attr_norm)] = 0.0 + series_norm = camels.transNormbyDic(series_data, seriesvarLst, statDict, toNorm=True) + + # prepare the inputs + xTrain = series_norm[:, :, :-1] # forcing, not include obs + xTrain[np.isnan(xTrain)] = 0.0 + yTrain = np.expand_dims(series_norm[:, :, -1], 2) + attrs = attr_norm + + +# define model and update configure +if torch.cuda.is_available(): + optModel = default.optLstm +else: + optModel = default.update( + default.optLstm, + name='hydroDL.model.rnn.CpuLstmModel') +optModel = default.update(default.optLstm, hiddenSize=HIDDENSIZE) +# define loss function +optLoss = default.optLossRMSE +# define training options +optTrain = default.update(default.optTrainCamels, miniBatch=[BATCH_SIZE, RHO], nEpoch=EPOCH, saveEpoch=saveEPOCH, seed=seedid) + +# define output folder for model results +exp_name = 'CAMELSDemo' +exp_disp = 'TestRun' +save_path = os.path.join(exp_name, exp_disp, \ + 'epochs{}_batch{}_rho{}_hiddensize{}_Tstart{}_Tend{}'.format(optTrain['nEpoch'], optTrain['miniBatch'][0], + optTrain['miniBatch'][1], + optModel['hiddenSize'], + optData['tRange'][0], optData['tRange'][1])) + +# Train the base model without data integration +if 0 in Action: + out = os.path.join(rootOut, save_path, 'All-85-95') # output folder to save results + # Wrap up all the training configurations to one dictionary in order to save into "out" folder + masterDict = master.wrapMaster(out, optData, optModel, optLoss, optTrain) + if interfaceOpt == 1: # use the more interpretable version interface + nx = xTrain.shape[-1] + attrs.shape[-1] # update nx, nx = nx + nc + ny = yTrain.shape[-1] + # load model for training + if torch.cuda.is_available(): + model = rnn.CudnnLstmModel(nx=nx, ny=ny, hiddenSize=HIDDENSIZE) + else: + model = rnn.CpuLstmModel(nx=nx, ny=ny, hiddenSize=HIDDENSIZE) + optModel = default.update(optModel, nx=nx, ny=ny) + # the loaded model should be consistent with the 'name' in optModel Dict above for logging purpose + lossFun = crit.RmseLoss() + # the loaded loss should be consistent with the 'name' in optLoss Dict above for logging purpose + # update and write the dictionary variable to out folder for logging and future testing + masterDict = master.wrapMaster(out, optData, optModel, optLoss, optTrain) + master.writeMasterFile(masterDict) + # log statistics + statFile = os.path.join(out, 'statDict.json') + with open(statFile, 'w') as fp: + json.dump(statDict, fp, indent=4) + # train model + model = train.trainModel( + model, + xTrain, + yTrain, + attrs, + lossFun, + nEpoch=EPOCH, + miniBatch=[BATCH_SIZE, RHO], + saveEpoch=saveEPOCH, + saveFolder=out) + elif interfaceOpt==0: # directly train the model using dictionary variable + master.train(masterDict) + + +# Train DI model +if 1 in Action: + nDayLst = [1,3] + for nDay in nDayLst: + # nDay: previous Nth day observation to 
integrate + # update parameter "daObs" for data dictionary variable + optData = default.update(default.optDataCamels, daObs=nDay) + # define output folder for DI models + out = os.path.join(rootOut, save_path, 'All-85-95-DI' + str(nDay)) + masterDict = master.wrapMaster(out, optData, optModel, optLoss, optTrain) + if interfaceOpt==1: + # optData['daObs'] != 0, load previous observation data to integrate + sd = utils.time.t2dt( + optData['tRange'][0]) - dt.timedelta(days=nDay) + ed = utils.time.t2dt( + optData['tRange'][1]) - dt.timedelta(days=nDay) + dfdi = camels.DataframeCamels( + subset=optData['subset'], tRange=[sd, ed]) + datatemp = dfdi.getDataObs( + doNorm=False, rmNan=False, basinnorm=True) # 'basinnorm=True': output = discharge/(area*mean_precip) + # normalize data + dadata = camels.transNormbyDic(datatemp, 'runoff', statDict, toNorm=True) + dadata[np.where(np.isnan(dadata))] = 0.0 + + xIn = np.concatenate([xTrain, dadata], axis=2) + nx = xIn.shape[-1] + attrs.shape[-1] # update nx, nx = nx + nc + ny = yTrain.shape[-1] + # load model for training + if torch.cuda.is_available(): + model = rnn.CudnnLstmModel(nx=nx, ny=ny, hiddenSize=HIDDENSIZE) + else: + model = rnn.CpuLstmModel(nx=nx, ny=ny, hiddenSize=HIDDENSIZE) + optModel = default.update(optModel, nx=nx, ny=ny) + lossFun = crit.RmseLoss() + # update and write dictionary variable to out folder for logging and future testing + masterDict = master.wrapMaster(out, optData, optModel, optLoss, optTrain) + master.writeMasterFile(masterDict) + # log statistics + statFile = os.path.join(out, 'statDict.json') + with open(statFile, 'w') as fp: + json.dump(statDict, fp, indent=4) + # train model + model = train.trainModel( + model, + xIn, + yTrain, + attrs, + lossFun, + nEpoch=EPOCH, + miniBatch=[BATCH_SIZE, RHO], + saveEpoch=saveEPOCH, + saveFolder=out) + elif interfaceOpt==0: + master.train(masterDict) + +# Test models +if 2 in Action: + TestEPOCH = 300 # choose the model to test after trained "TestEPOCH" epoches + # generate a folder name list containing all the tested model output folders + caseLst = ['All-85-95'] + nDayLst = [1, 3] # which DI models to test: DI(1), DI(3) + for nDay in nDayLst: + caseLst.append('All-85-95-DI' + str(nDay)) + outLst = [os.path.join(rootOut, save_path, x) for x in caseLst] # outLst includes all the directories to test + subset = 'All' # 'All': use all the CAMELS gages to test; Or pass the gage list + tRange = [19951001, 20051001] # Testing period + testBatch = 100 # do batch forward to save GPU memory + predLst = list() + for out in outLst: + if interfaceOpt == 1: # use the more interpretable version interface + # load testing data + mDict = master.readMasterFile(out) + optData = mDict['data'] + df = camels.DataframeCamels( + subset=subset, tRange=tRange) + x = df.getDataTs( + varLst=optData['varT'], + doNorm=False, + rmNan=False) + obs = df.getDataObs( + doNorm=False, + rmNan=False, + basinnorm=False) + c = df.getDataConst( + varLst=optData['varC'], + doNorm=False, + rmNan=False) + + # do normalization and remove nan + # load the saved statDict to make sure using the same statistics as training data + statFile = os.path.join(out, 'statDict.json') + with open(statFile, 'r') as fp: + statDict = json.load(fp) + seriesvarLst = optData['varT'] + attrLst = optData['varC'] + attr_norm = camels.transNormbyDic(c, attrLst, statDict, toNorm=True) + attr_norm[np.isnan(attr_norm)] = 0.0 + xTest = camels.transNormbyDic(x, seriesvarLst, statDict, toNorm=True) + xTest[np.isnan(xTest)] = 0.0 + attrs = attr_norm + + if 
optData['daObs'] > 0: + # optData['daObs'] != 0, load previous observation data to integrate + nDay = optData['daObs'] + sd = utils.time.t2dt( + tRange[0]) - dt.timedelta(days=nDay) + ed = utils.time.t2dt( + tRange[1]) - dt.timedelta(days=nDay) + dfdi = camels.DataframeCamels( + subset=subset, tRange=[sd, ed]) + datatemp = dfdi.getDataObs( + doNorm=False, rmNan=False, basinnorm=True) # 'basinnorm=True': output = discharge/(area*mean_precip) + # normalize data + dadata = camels.transNormbyDic(datatemp, 'runoff', statDict, toNorm=True) + dadata[np.where(np.isnan(dadata))] = 0.0 + xIn = np.concatenate([xTest, dadata], axis=2) + + else: + xIn = xTest + + # load and forward the model for testing + testmodel = loadModel(out, epoch=TestEPOCH) + filePathLst = master.master.namePred( + out, tRange, 'All', epoch=TestEPOCH) # prepare the name of csv files to save testing results + train.testModel( + testmodel, xIn, c=attrs, batchSize=testBatch, filePathLst=filePathLst) + # read out predictions + dataPred = np.ndarray([obs.shape[0], obs.shape[1], len(filePathLst)]) + for k in range(len(filePathLst)): + filePath = filePathLst[k] + dataPred[:, :, k] = pd.read_csv( + filePath, dtype=np.float, header=None).values + # transform back to the original observation + temppred = camels.transNormbyDic(dataPred, 'runoff', statDict, toNorm=False) + pred = camels.basinNorm(temppred, subset, toNorm=False) + + elif interfaceOpt == 0: # only for models trained by the pro interface + df, pred, obs = master.test(out, tRange=tRange, subset=subset, batchSize=testBatch, basinnorm=True, + epoch=TestEPOCH, reTest=True) + + # change the units ft3/s to m3/s + obs = obs * 0.0283168 + pred = pred * 0.0283168 + predLst.append(pred) # the prediction list for all the models + + # calculate statistic metrics + statDictLst = [stat.statError(x.squeeze(), obs.squeeze()) for x in predLst] + + # Show boxplots of the results + plt.rcParams['font.size'] = 14 + keyLst = ['Bias', 'NSE', 'FLV', 'FHV'] + dataBox = list() + for iS in range(len(keyLst)): + statStr = keyLst[iS] + temp = list() + for k in range(len(statDictLst)): + data = statDictLst[k][statStr] + data = data[~np.isnan(data)] + temp.append(data) + dataBox.append(temp) + labelname = ['LSTM'] + for nDay in nDayLst: + labelname.append('DI(' + str(nDay) + ')') + xlabel = ['Bias ($\mathregular{m^3}$/s)', 'NSE', 'FLV(%)', 'FHV(%)'] + fig = plot.plotBoxFig(dataBox, xlabel, labelname, sharey=False, figsize=(12, 5)) + fig.patch.set_facecolor('white') + fig.show() + # plt.savefig(os.path.join(rootOut, save_path, "Boxplot.png"), dpi=500) + + # Plot timeseries and locations + plt.rcParams['font.size'] = 12 + # get Camels gages info + gageinfo = camels.gageDict + gagelat = gageinfo['lat'] + gagelon = gageinfo['lon'] + # randomly select 7 gages to plot + gageindex = np.random.randint(0, 671, size=7).tolist() + plat = gagelat[gageindex] + plon = gagelon[gageindex] + t = utils.time.tRange2Array(tRange) + fig, axes = plt.subplots(4,2, figsize=(12,10), constrained_layout=True) + axes = axes.flat + npred = 2 # plot the first two prediction: Base LSTM and DI(1) + subtitle = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(k)', '(l)'] + txt = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'k'] + ylabel = 'Flow rate ($\mathregular{m^3}$/s)' + for k in range(len(gageindex)): + iGrid = gageindex[k] + yPlot = [obs[iGrid, :]] + for y in predLst[0:npred]: + yPlot.append(y[iGrid, :]) + # get the NSE value of LSTM and DI(1) model + NSE_LSTM = str(round(statDictLst[0]['NSE'][iGrid], 2)) + NSE_DI1 = 
str(round(statDictLst[1]['NSE'][iGrid], 2)) + # plot time series + plot.plotTS( + t, + yPlot, + ax=axes[k], + cLst='kbrmg', + markerLst='---', + legLst=['USGS', 'LSTM: '+NSE_LSTM, 'DI(1): '+NSE_DI1], title=subtitle[k], linespec=['-',':',':'], ylabel=ylabel) + # plot gage location + plot.plotlocmap(plat, plon, ax=axes[-1], baclat=gagelat, baclon=gagelon, title=subtitle[-1], txtlabel=txt) + fig.patch.set_facecolor('white') + fig.show() + # plt.savefig(os.path.join(rootOut, save_path, "/Timeseries.png"), dpi=500) + + # Plot NSE spatial patterns + gageinfo = camels.gageDict + gagelat = gageinfo['lat'] + gagelon = gageinfo['lon'] + nDayLst = [1, 3] + fig, axs = plt.subplots(3,1, figsize=(8,8), constrained_layout=True) + axs = axs.flat + data = statDictLst[0]['NSE'] + plot.plotMap(data, ax=axs[0], lat=gagelat, lon=gagelon, title='(a) LSTM', cRange=[0.0, 1.0], shape=None) + data = statDictLst[1]['NSE'] + plot.plotMap(data, ax=axs[1], lat=gagelat, lon=gagelon, title='(b) DI(1)', cRange=[0.0, 1.0], shape=None) + deltaNSE = statDictLst[1]['NSE'] - statDictLst[0]['NSE'] + plot.plotMap(deltaNSE, ax=axs[2], lat=gagelat, lon=gagelon, title='(c) Delta NSE', shape=None) + fig.show() + # plt.savefig(os.path.join(rootOut, save_path, "/NSEPattern.png"), dpi=500) diff --git a/example/data/CONUSv4f1/2015/APCP_FORA.csv b/example/data/CONUSv4f1/2015/APCP_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2015/DLWRF_FORA.csv b/example/data/CONUSv4f1/2015/DLWRF_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2015/DSWRF_FORA.csv b/example/data/CONUSv4f1/2015/DSWRF_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2015/SMAP_AM.csv b/example/data/CONUSv4f1/2015/SMAP_AM.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2015/SPFH_2_FORA.csv b/example/data/CONUSv4f1/2015/SPFH_2_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2015/TMP_2_FORA.csv b/example/data/CONUSv4f1/2015/TMP_2_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2015/UGRD_10_FORA.csv b/example/data/CONUSv4f1/2015/UGRD_10_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2015/VGRD_10_FORA.csv b/example/data/CONUSv4f1/2015/VGRD_10_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2015/time.csv b/example/data/CONUSv4f1/2015/time.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2015/timeStr.csv b/example/data/CONUSv4f1/2015/timeStr.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2016/APCP_FORA.csv b/example/data/CONUSv4f1/2016/APCP_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2016/DLWRF_FORA.csv b/example/data/CONUSv4f1/2016/DLWRF_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2016/DSWRF_FORA.csv b/example/data/CONUSv4f1/2016/DSWRF_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2016/SMAP_AM.csv b/example/data/CONUSv4f1/2016/SMAP_AM.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2016/SPFH_2_FORA.csv b/example/data/CONUSv4f1/2016/SPFH_2_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2016/TMP_2_FORA.csv b/example/data/CONUSv4f1/2016/TMP_2_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2016/UGRD_10_FORA.csv b/example/data/CONUSv4f1/2016/UGRD_10_FORA.csv old mode 100755 new mode 100644 diff --git 
a/example/data/CONUSv4f1/2016/VGRD_10_FORA.csv b/example/data/CONUSv4f1/2016/VGRD_10_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2016/time.csv b/example/data/CONUSv4f1/2016/time.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2016/timeStr.csv b/example/data/CONUSv4f1/2016/timeStr.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2017/APCP_FORA.csv b/example/data/CONUSv4f1/2017/APCP_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2017/DLWRF_FORA.csv b/example/data/CONUSv4f1/2017/DLWRF_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2017/DSWRF_FORA.csv b/example/data/CONUSv4f1/2017/DSWRF_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2017/SPFH_2_FORA.csv b/example/data/CONUSv4f1/2017/SPFH_2_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2017/TMP_2_FORA.csv b/example/data/CONUSv4f1/2017/TMP_2_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2017/UGRD_10_FORA.csv b/example/data/CONUSv4f1/2017/UGRD_10_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2017/VGRD_10_FORA.csv b/example/data/CONUSv4f1/2017/VGRD_10_FORA.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2017/time.csv b/example/data/CONUSv4f1/2017/time.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/2017/timeStr.csv b/example/data/CONUSv4f1/2017/timeStr.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/Bulk.csv b/example/data/CONUSv4f1/const/Bulk.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/Capa.csv b/example/data/CONUSv4f1/const/Capa.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/Clay.csv b/example/data/CONUSv4f1/const/Clay.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/NDVI.csv b/example/data/CONUSv4f1/const/NDVI.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/Sand.csv b/example/data/CONUSv4f1/const/Sand.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/Silt.csv b/example/data/CONUSv4f1/const/Silt.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/flag_albedo.csv b/example/data/CONUSv4f1/const/flag_albedo.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/flag_extraOrd.csv b/example/data/CONUSv4f1/const/flag_extraOrd.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/flag_landcover.csv b/example/data/CONUSv4f1/const/flag_landcover.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/flag_roughness.csv b/example/data/CONUSv4f1/const/flag_roughness.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/flag_vegDense.csv b/example/data/CONUSv4f1/const/flag_vegDense.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/const/flag_waterbody.csv b/example/data/CONUSv4f1/const/flag_waterbody.csv old mode 100755 new mode 100644 diff --git a/example/data/CONUSv4f1/crd.csv b/example/data/CONUSv4f1/crd.csv old mode 100755 new mode 100644 diff --git a/example/data/DATA.md b/example/data/DATA.md new file mode 100644 index 0000000..b17eeea --- /dev/null +++ b/example/data/DATA.md @@ -0,0 +1,83 @@ +# Database description +## Database Structure +``` +├── CONUS +│   ├── 2000 +│   │   ├── [Variable-Name].csv +│   │   ├── ... 
+│   │   ├── timeStr.csv
+│   │   └── time.csv
+│   ├── ...
+│   ├── 2017
+│   │   └── ...
+│   ├── const
+│   │   ├── [Constant-Variable-Name].csv
+│   │   └── ...
+│   └── crd.csv
+├── CONUSv4f1
+│   └── ...
+├── Statistics
+│   ├── [Variable-Name]_stat.csv
+│   ├── ...
+│   ├── const_[Constant-Variable-Name]_stat.csv
+│   └── ...
+├── Subset
+│   ├── CONUS.csv
+│   └── CONUSv4f1.csv
+└── Variable
+    ├── varConstLst.csv
+    └── varLst.csv
+```
+### 1. Dataset folders (*CONUS*, *CONUSv4f1*)
+A dataset folder contains all data, both training and testing, covering time-dependent variables and constant variables.
+In the example data structure there are two dataset folders, *CONUS* and *CONUSv4f1*. Data are saved as:
+
+- **year/[Variable-Name].csv**:
+
+A csv file of size [#grid, #time]: each row is one grid cell and each column is one time step. The file stores one year of data for a time-dependent variable. For example, *CONUS/2010/SMAP_AM.csv* is the SMAP data of 2010 over the CONUS.
+
+Most time-dependent variables come from NLDAS, which includes two forcing products (FORA, FORB) and simulations from three land surface models (NOAH, MOS, VIC). Variables are named *[variable]\_[product]\_[layer]*, and the variable reference can be found in the [NLDAS document](https://hydro1.gesdisc.eosdis.nasa.gov/data/NLDAS/README.NLDAS2.pdf). For example, *SOILM_NOAH_0-10* refers to the soil moisture simulated by the NOAH model at 0-10 cm depth.
+
+Besides NLDAS, SMAP data are saved in the same format but are always used as the target. In the level 3 database there are two SMAP csv files, *SMAP_AM.csv* and *SMAP_PM.csv*, which are only available from 2015 onward.
+
+-9999 refers to NaN.
+
+- **year/time.csv** & **timeStr.csv**
+
+Date csv files of the current year folder, of size [#date]. *time.csv* records Matlab datenums and *timeStr.csv* records dates in yyyy-mm-dd format.
+
+Notice that each year starts on and ends before April 1st; for example, the folder 2010 actually contains data from 2010-04-01 to 2011-03-31. The reason is that the SMAP record begins on April 1st.
+
+- **const/[Constant Variable Name].csv**
+
+A csv file for one constant variable, of size [#grid].
+
+- **crd.csv**
+
+Coordinates of all grids. The first column is latitude and the second column is longitude; each row refers to one grid.
+
+### 2. Statistics folder
+
+Stores statistics of the variables, used for data normalization during training. Files are named:
+- Time-dependent variables -> [variable name]_stat.csv
+- Constant variables -> const_[variable name]_stat.csv
+
+Each file records four statistics of its variable:
+- 90th percentile
+- 10th percentile
+- mean
+- std
+
+During training we normalize data as (data - mean) / std (see the reading/normalization sketch at the end of this document).
+
+### 3. Subset folder
+A subset refers to a subset of grids taken from a complete dataset (CONUS or Global); for example, a subset containing only the grids in Pennsylvania. Every subset (including the CONUS or Global dataset itself) has a *[subset name].csv* file in the *Subset* folder. *[subset name].csv* is written as:
+- line 1 -> name of the root dataset
+- lines 2 to end -> indices of the subset grids in the root dataset (starting from 1)
+
+An index of -1 means all grids, as in the CONUS dataset, for example.
+
+### 4. Variable folder
+Stores csv files that each contain a list of variables, used as input to the training code. Time-dependent variables and constant variables are stored in separate files. For example:
+- varLst.csv -> a list of time-dependent variables used as training predictors.
+- varConstLst.csv -> a list of constant variables used as training predictors.
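+
+To make the layout concrete, here is a minimal sketch that reads one time-dependent variable for one year and applies the normalization above. It assumes only the folder layout described in this document; the database root, dataset name, year, and variable are placeholder choices, and plain pandas/numpy is used instead of the hydroDL reader classes.
+
+```python
+import os
+
+import numpy as np
+import pandas as pd
+
+rootDB = '/path/to/database'  # placeholder database root
+dataset, year, var = 'CONUSv4f1', 2016, 'APCP_FORA'
+
+# data csv: one row per grid, one column per time step; -9999 marks NaN
+data = pd.read_csv(os.path.join(rootDB, dataset, str(year), var + '.csv'),
+                   header=None).values.astype(float)
+data[data == -9999] = np.nan
+
+# the stat csv stores four values: 90th percentile, 10th percentile, mean, std
+stat = pd.read_csv(os.path.join(rootDB, 'Statistics', var + '_stat.csv'),
+                   header=None).values.flatten()
+mean, std = stat[2], stat[3]
+dataNorm = (data - mean) / std  # the normalization used during training
+```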
diff --git a/example/data/Statistics/APCP_FORA_stat.csv b/example/data/Statistics/APCP_FORA_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/DLWRF_FORA_stat.csv b/example/data/Statistics/DLWRF_FORA_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/DSWRF_FORA_stat.csv b/example/data/Statistics/DSWRF_FORA_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/SMAP_AM_stat.csv b/example/data/Statistics/SMAP_AM_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/SPFH_2_FORA_stat.csv b/example/data/Statistics/SPFH_2_FORA_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/TMP_2_FORA_stat.csv b/example/data/Statistics/TMP_2_FORA_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/UGRD_10_FORA_stat.csv b/example/data/Statistics/UGRD_10_FORA_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/VGRD_10_FORA_stat.csv b/example/data/Statistics/VGRD_10_FORA_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_Bulk_stat.csv b/example/data/Statistics/const_Bulk_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_Capa_stat.csv b/example/data/Statistics/const_Capa_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_Clay_stat.csv b/example/data/Statistics/const_Clay_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_NDVI_stat.csv b/example/data/Statistics/const_NDVI_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_Sand_stat.csv b/example/data/Statistics/const_Sand_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_Silt_stat.csv b/example/data/Statistics/const_Silt_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_flag_albedo_stat.csv b/example/data/Statistics/const_flag_albedo_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_flag_extraOrd_stat.csv b/example/data/Statistics/const_flag_extraOrd_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_flag_landcover_stat.csv b/example/data/Statistics/const_flag_landcover_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_flag_roughness_stat.csv b/example/data/Statistics/const_flag_roughness_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_flag_vegDense_stat.csv b/example/data/Statistics/const_flag_vegDense_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Statistics/const_flag_waterbody_stat.csv b/example/data/Statistics/const_flag_waterbody_stat.csv old mode 100755 new mode 100644 diff --git a/example/data/Subset/CONUSv4f1.csv b/example/data/Subset/CONUSv4f1.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varConstLst.csv b/example/data/Variable/varConstLst.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varConstLst_Noah.csv b/example/data/Variable/varConstLst_Noah.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst.csv b/example/data/Variable/varLst.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_APCP_rn1e0.csv b/example/data/Variable/varLst_APCP_rn1e0.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_APCP_rn1e1.csv b/example/data/Variable/varLst_APCP_rn1e1.csv old mode 100755 new 
mode 100644 diff --git a/example/data/Variable/varLst_APCP_rn2e0.csv b/example/data/Variable/varLst_APCP_rn2e0.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_APCP_rn2e1.csv b/example/data/Variable/varLst_APCP_rn2e1.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_APCP_rn3e1.csv b/example/data/Variable/varLst_APCP_rn3e1.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_APCP_rn4e1.csv b/example/data/Variable/varLst_APCP_rn4e1.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_APCP_rn5e1.csv b/example/data/Variable/varLst_APCP_rn5e1.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_APCP_rn5e2.csv b/example/data/Variable/varLst_APCP_rn5e2.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_Forcing.csv b/example/data/Variable/varLst_Forcing.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_Forcing_noAPCP.csv b/example/data/Variable/varLst_Forcing_noAPCP.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_Forcing_noSPFH.csv b/example/data/Variable/varLst_Forcing_noSPFH.csv old mode 100755 new mode 100644 diff --git a/example/data/Variable/varLst_soilM.csv b/example/data/Variable/varLst_soilM.csv old mode 100755 new mode 100644 diff --git a/example/demo-LSTM-Tutorial.ipynb b/example/demo-LSTM-Tutorial.ipynb index 644744e..55e7277 100644 --- a/example/demo-LSTM-Tutorial.ipynb +++ b/example/demo-LSTM-Tutorial.ipynb @@ -114,7 +114,6 @@ { "ename": "KeyboardInterrupt", "evalue": "", - "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", @@ -124,7 +123,8 @@ "\u001b[1;32mC:\\pythonenvir\\lib\\site-packages\\torch\\tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[0;32m 116\u001b[0m \u001b[0mproducts\u001b[0m\u001b[1;33m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 117\u001b[0m \"\"\"\n\u001b[1;32m--> 118\u001b[1;33m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 119\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 120\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mC:\\pythonenvir\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[0;32m 91\u001b[0m Variable._execution_engine.run_backward(\n\u001b[0;32m 92\u001b[0m \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 93\u001b[1;33m allow_unreachable=True) # allow_unreachable flag\n\u001b[0m\u001b[0;32m 94\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " - ] + ], + "output_type": "error" } ], "source": [ @@ -214,777 +214,6 @@ "outputs": [ { "data": {
- "application/javascript": [ ... ],  [~770 deleted lines of embedded matplotlib nbagg JavaScript output, the notebook's interactive-figure frontend, elided]
"text/plain": [ "" ] @@ -1828,7 +286,7 @@ ] }, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" }, { "name": "stderr", @@ -1896,8 +354,8 @@ "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, - "version_major": 2, - "version_minor": 0 + "version_major": 2.0, + "version_minor": 0.0 } } },
diff --git a/example/demo-temporal-test.ipynb b/example/demo-temporal-test.ipynb deleted file mode 100644 index 162a0f0..0000000 --- a/example/demo-temporal-test.ipynb +++ /dev/null @@ -1,940 +0,0 @@
-{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# demo\n", - "This is a demo for model temporal test and plot the result map and time series.
Before this we trained a model using [train-lstm.py](train-lstm.py). By default the model will be saved in [here](output/CONUSv4f1/)." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ " - Load packages and target SMAP observation" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loading package hydroDL\n", "/home/kxf227/work/GitHUB/pyRnnSMAP/example/data/Subset/CONUSv4f1.csv\n", "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/SMAP_AM.csv 0.04537510871887207\n" ] } ], "source": [ "import os\n", "from hydroDL.data import dbCsv\n", "from hydroDL.post import plot, stat\n", "from hydroDL import master\n", "\n", "cDir = os.getcwd()\n", "rootDB = os.path.join(cDir, 'data')\n", "tRange = [20160401, 20170401]\n", "df = dbCsv.DataframeCsv(\n", "    rootDB=rootDB, subset='CONUSv4f1', tRange=tRange)\n", "yt = df.getData(varT='SMAP_AM', doNorm=False, rmNan=False)\n", "yt = yt.squeeze()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ " - Test the model in another year" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/kxf227/work/GitHUB/pyRnnSMAP/example/data/Subset/CONUSv4f1.csv\n", "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/APCP_FORA.csv 0.044591665267944336\n", "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/DLWRF_FORA.csv 0.052686452865600586\n", "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/DSWRF_FORA.csv 0.050998687744140625\n", "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/TMP_2_FORA.csv 0.051717281341552734\n", "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/SPFH_2_FORA.csv 0.05404353141784668\n", "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/VGRD_10_FORA.csv 0.051822662353515625\n", "read /home/kxf227/work/GitHUB/pyRnnSMAP/example/data/CONUSv4f1/2016/UGRD_10_FORA.csv 0.0521092414855957\n" ] } ], "source": [ "out = os.path.join(cDir, 'output', 'CONUSv4f1')\n", "yp = master.test(\n", "    out, tRange=tRange, subset='CONUSv4f1')\n", "yp = yp.squeeze()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ " - Calculate statistic metrices and plot the result. An interactive map will be generated, where users can click on map to show time series of observation and model predictions. " ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/javascript": [
[... embedded matplotlib nbagg JavaScript output of the deleted notebook (same interactive-figure frontend as elided above) ...]