From fcf5b97238f16e53e6b97f6eb64425ab9f18bb8b Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Mon, 10 Feb 2025 17:35:29 +0000 Subject: [PATCH 1/8] dataset --- docs/api/datasets.rst | 5 + docs/api/datasets/non_geo_datasets.csv | 1 + .../data/dior/Annotations/trainval/000000.xml | 1 + .../data/dior/Annotations/trainval/000001.xml | 1 + .../data/dior/Annotations/trainval/000002.xml | 1 + .../data/dior/Annotations/trainval/000003.xml | 1 + .../data/dior/Annotations/trainval/000004.xml | 1 + .../data/dior/Annotations/trainval/000005.xml | 1 + tests/data/dior/Annotations_trainval.zip | Bin 0 -> 2012 bytes tests/data/dior/Images/test/000000.jpg | Bin 0 -> 1217 bytes tests/data/dior/Images/test/000001.jpg | Bin 0 -> 1238 bytes tests/data/dior/Images/trainval/000000.jpg | Bin 0 -> 1240 bytes tests/data/dior/Images/trainval/000001.jpg | Bin 0 -> 1246 bytes tests/data/dior/Images/trainval/000002.jpg | Bin 0 -> 1244 bytes tests/data/dior/Images/trainval/000003.jpg | Bin 0 -> 1243 bytes tests/data/dior/Images/trainval/000004.jpg | Bin 0 -> 1238 bytes tests/data/dior/Images/trainval/000005.jpg | Bin 0 -> 1242 bytes tests/data/dior/Images_test.zip | Bin 0 -> 2511 bytes tests/data/dior/Images_trainval.zip | Bin 0 -> 7440 bytes tests/data/dior/data.py | 171 ++++++++ tests/data/dior/sample_df.parquet | Bin 0 -> 3186 bytes tests/datasets/test_dior.py | 107 +++++ torchgeo/datasets/dior.py | 393 ++++++++++++++++++ 23 files changed, 683 insertions(+) create mode 100644 tests/data/dior/Annotations/trainval/000000.xml create mode 100644 tests/data/dior/Annotations/trainval/000001.xml create mode 100644 tests/data/dior/Annotations/trainval/000002.xml create mode 100644 tests/data/dior/Annotations/trainval/000003.xml create mode 100644 tests/data/dior/Annotations/trainval/000004.xml create mode 100644 tests/data/dior/Annotations/trainval/000005.xml create mode 100644 tests/data/dior/Annotations_trainval.zip create mode 100644 tests/data/dior/Images/test/000000.jpg create mode 100644 
tests/data/dior/Images/test/000001.jpg create mode 100644 tests/data/dior/Images/trainval/000000.jpg create mode 100644 tests/data/dior/Images/trainval/000001.jpg create mode 100644 tests/data/dior/Images/trainval/000002.jpg create mode 100644 tests/data/dior/Images/trainval/000003.jpg create mode 100644 tests/data/dior/Images/trainval/000004.jpg create mode 100644 tests/data/dior/Images/trainval/000005.jpg create mode 100644 tests/data/dior/Images_test.zip create mode 100644 tests/data/dior/Images_trainval.zip create mode 100644 tests/data/dior/data.py create mode 100644 tests/data/dior/sample_df.parquet create mode 100644 tests/datasets/test_dior.py create mode 100644 torchgeo/datasets/dior.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index d01a91dfe70..746b0151e4c 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -273,6 +273,11 @@ DFC2022 .. autoclass:: DFC2022 +DIOR +^^^^ + +.. autoclass:: DIOR + Digital Typhoon ^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index 1defcb032bd..c064a3da1b4 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -13,6 +13,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `Kenya Crop Type`_,S,Sentinel-2,"CC-BY-SA-4.0","4,688",7,"3,035x2,016",10,MSI `DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB `DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB +`DIOR`_,OD,Aerial,"CC-BY-SA","23,463",20,"800x800",0.5,RGB `Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared `ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR `EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI diff --git a/tests/data/dior/Annotations/trainval/000000.xml b/tests/data/dior/Annotations/trainval/000000.xml new file mode 100644 index 00000000000..9cf3dcb4484 --- /dev/null +++ 
b/tests/data/dior/Annotations/trainval/000000.xml @@ -0,0 +1 @@ +000000.jpg32323stadium5111721stadium1322222airplane901916 \ No newline at end of file diff --git a/tests/data/dior/Annotations/trainval/000001.xml b/tests/data/dior/Annotations/trainval/000001.xml new file mode 100644 index 00000000000..e4a7f1f133a --- /dev/null +++ b/tests/data/dior/Annotations/trainval/000001.xml @@ -0,0 +1 @@ +000001.jpg32323baseballfield502320basketballcourt992428 \ No newline at end of file diff --git a/tests/data/dior/Annotations/trainval/000002.xml b/tests/data/dior/Annotations/trainval/000002.xml new file mode 100644 index 00000000000..8f6a8784a50 --- /dev/null +++ b/tests/data/dior/Annotations/trainval/000002.xml @@ -0,0 +1 @@ +000002.jpg32323expresswayservicearea151524harbor812127chimney182626 \ No newline at end of file diff --git a/tests/data/dior/Annotations/trainval/000003.xml b/tests/data/dior/Annotations/trainval/000003.xml new file mode 100644 index 00000000000..012639937ee --- /dev/null +++ b/tests/data/dior/Annotations/trainval/000003.xml @@ -0,0 +1 @@ +000003.jpg32323expresswayservicearea522316bridge1172221 \ No newline at end of file diff --git a/tests/data/dior/Annotations/trainval/000004.xml b/tests/data/dior/Annotations/trainval/000004.xml new file mode 100644 index 00000000000..cb58b98b9d0 --- /dev/null +++ b/tests/data/dior/Annotations/trainval/000004.xml @@ -0,0 +1 @@ +000004.jpg32323baseballfield11142025bridge462123basketballcourt7121931 \ No newline at end of file diff --git a/tests/data/dior/Annotations/trainval/000005.xml b/tests/data/dior/Annotations/trainval/000005.xml new file mode 100644 index 00000000000..97ea68e8c1d --- /dev/null +++ b/tests/data/dior/Annotations/trainval/000005.xml @@ -0,0 +1 @@ +000005.jpg32323expresswaytollstation1073119vehicle15122529 \ No newline at end of file diff --git a/tests/data/dior/Annotations_trainval.zip b/tests/data/dior/Annotations_trainval.zip new file mode 100644 index 
0000000000000000000000000000000000000000..ed45b6be8d01da57e770be3ce919e9adfeb1190b GIT binary patch literal 2012 zcmWIWW@Zs#0D*U%UQu8Mln@2dj(K_cC5a`O`FX|qB}Ivud1Z+?`T7DqR?~npd>-`g4-Lq7A-(TX)i#3aoIrj4M>7$aKCX)7i z?ei|_v#4axRX=}DYtM%SonE)Mmv4Oy?*5?p=*;;A<=6HFPjxxLXMH&LdGOQRz}cGL z6jGWetlzcF{JFsL!p-Vuw+kxr=da1Tv#3Dw<|<~VTOS2OIa*SjJi3LJ&tMb2w7_WT zM~>G#5stGP7io2DV(^M&dg`X|u~h2BtocWFF}yo1evLIi=0Eep*Z<>Cefs8DNW~JM z57Pw} zT((k854v>g+Jvhdk_+dA)d}wCIo7G{w5sBHjKUHr+npVol8n1Die|^;CB`ypznxvR z?7*F4@%j(dx8Eyxt^0k$(|P}jA3omVkLtm|31KNafgZd@(1S*hsC>xRq#)9Ie}(jv zXZ_aqSih=mVPV(Xx}0AnI%n(2{$l>`2mT&&$n24Qx?-(h#np17P3na`-STdmeZcT#gzq5{;G)kzgs9hsZYhp z>FTnzZ*>aZZ58B6T)V?XSKykIktKK75rY{5Y)_xhIHsGUoUk~pL^U?|u7AwbhY@X$ zb&kiDJ(5~l?R>iLd8p<3GnG&4VdpC7x<>j%&{>=}61T7Sk@vE=zUzVPZHV{aX9Q zLJ?W_8!~+z_I%)gttcB?k#7u6b&x$I$iaPVe`?n=sySd zy|;T;x4ZF>#Os#J*9z8@`n&NQ*<5hS+1!j{zS6m0FW>JDTldtpZ-Z8jr}^6MsoPy# zlPj7xaJ6i`da6%L{&-=~85P(2*Jo9=Zcthhcu7xYx})?t2`7;s9ot+w4lPf5SSc2y zUnBL#aLK_#W-MlLr{-?S$`FrG+Y{WCthw%oxcW2gof|Sz^C#baI`6XWXO=U1oI4Ws zAN=rjCR*9#_<7#Bl|XN25%jhx$lJ9Ovbhc`2(;v1S=$!C)m?6Q(oCwS%}wg^oVnWO7#WEXkA+TY3HYH z78fQ}33NGXgvXi(%0IU0J5;p1^zM^)a}BdAnVlR<%U;9?NTlVmS}UhCS6||Gx|gic zoDspbxHB=N%s12OkRdZ)PHX*yOM4gIKd_C-C{D04TJWd5QqA9Yzy!?5B*Kil*amtT z2m~13I)Z2%g*eOz?8P`pBM2;MGzOAbjX)IY2=_n?LoX@8CV>r;1;zw!!@%Vxx?$+0 z62h>5cnt%WQRs%DXLf{PUzqT@44n1R4MWdW2*YkL<1-AL$IuN!Ppb&S&a>b%44iTU Tyjj^m8rXqQ5a^MutRNl$sBrII literal 0 HcmV?d00001 diff --git a/tests/data/dior/Images/test/000000.jpg b/tests/data/dior/Images/test/000000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..55cd3f2144b2a670c1573350806a94298dc90abe GIT binary patch literal 1217 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L&};>w)>WDbcpUZyV0*_Y&^;7q1kU48K&FQ zpS}-_I+?(_Ylp68d0DH3>=_X&x7i2NRJ^z6v0ZL(JfLgj%2Tw)DyeqPox7c0r@xDK z7XCg|ZpPRnP$8L<%(_`$6|&OKy6;{H)q_-7mN97hm4)?|WAIwEJb%S+_J28jkg_Ph_4py}-WGfop!D 
zB|F0$0oSyoS01Mv9=&>ew$gd3(6)ObyCz-@JjAo=8#@dCOkt}zF5X7c(d!gq{=Kl& zx^HoZ@A0R;1q^n1ZEI)yZ?^GV7IV|PuBlKYF-%hQbn`~D9(Lt2e!sOZu9k&N$`4<+ zJNtR=*4N)FMHPNtk&4Z)4cm2p{oVD8onl-LzTNKh-tMsEw3h~5PyKwBImL)(?(Y$P zR~WQpbL{PySnhNV2h}4hpYFLGz@EckWa+cw_$g22OjDNIK{t;3qNB9GU^lx`8W2qJYE}d8=reY_`f9KWF36mnL@&XL!T=Xgk-mbsv g?z?Yg_di|#@VdHda(Q>{{Izc{Z_2*>r2hX+09RA-6951J literal 0 HcmV?d00001 diff --git a/tests/data/dior/Images/test/000001.jpg b/tests/data/dior/Images/test/000001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5c9f08916708f41ff161bc7837d5321e56b85043 GIT binary patch literal 1238 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L*|C2RH2(k@&XFhY|ws_6ml#i{n?!*%@f4m zJxGz+c0Q#jYR$9>cMkoI3h%X98nnPz_V;w1)z6-M@t=~u)hye(x+iY?%x}484uyDZ z+8wd|Ju^?f^^?gAjZtZ;dRu-TqtCZ7*Yxn9hZ1>go&)A%dT4 z_PRanQAl{1-er9wdG4+3ClwDT-tk!DwJc;oZ`Z1(k9PNt7cyH!wQ1geGqF2jR;2Hq zzzzD_vso32qIYec;AWS~l5SzHH0f$jaipV6uE`N+mH1s9qNk_aV0!*eBY*lchV-T# z>Tfk)ZF&9u*58sZuX8yjKlR=6>-FX5r8lFW++Px9dSidZIx$X#28S@wh(Cvpa>zQU z^zJxURn@bC_iFCarb?c-x_ckpo|)9TF7^o_M z*jsi#dN-%eQ+GP1qsAz2pzvqG>I6}C(Y5J<70Tj1+XEeyZa?EVa`fqqnQneX3m&U9 zd2l(z6m0SHbuYb~@AAd(m-_YED34WpzC8J|rDWIj)XlM9YWKVqTjTICc8b`eTNV!+ zbmzSB=+j#jTx=xRde6=3PqVxM%W-3lv_5P1?Pk~fnc0GzwS+$zZR)*mKCL|Ww-mE8VR`tmK(O7YSF(`WVPW6LrkwTUy+zTGd zygR6(kre&x!NeU2+AB3po=xhVw)*RTh_QQT?fRdg-U~pDJ(e-M;y%{{Kw? 
DN>>F( literal 0 HcmV?d00001 diff --git a/tests/data/dior/Images/trainval/000000.jpg b/tests/data/dior/Images/trainval/000000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..04d70c662713667f5ca45d7288fbcf0ab862ae1f GIT binary patch literal 1240 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L+8^jT|3e3TXsyl>pqibIhX|rJa*7j2!@8(%tGumiq>FLbF!}K5}&qma`yHL-# zXR7y$dAnZM{rVagws%%onAWCAdnZ5Jy!UtXdH=&@yC-luy$W{Esyno=>PSkrkieNW z%lRZ$BuKjPaq!$YE;&!F>rrrdrq!XxOm~$ZpKh5JP&Pr(^Vaf(sfLdwUQLWTSuA@f z^Rv<|m4{D$PIHMA;K|i17vi2jTl3R`S4NsH&8}0$o=%K2zpFmO$>?c>fbgUumWr5n zx1#edo>-TvGNmLldRzDYslSTtcROCGdG$8y-mRDQho8)STs!wfUfio&&#c2MzDb(j zaXERCz3^v*g=$Fdp>!P%?VTQnO@6LhI=gJ9X8K+o=CeW@(vLk_*!4+5fo1Z^;|x8V z`MGP~zCSa4Ug>IIrByqpY_5!C-1kx1E%5W4`4dE$nbw?5Q{6f5?gImzt_`hyF_sE> zmNB{Niwh1n85Q03TCH*Vajeynzh?g#%vXtIYwVSs{hy(&`s=0dQMy~AuXj}+-L%=7 zdN80;v0aj3;+Ck?Ip&3{-p;RDIk{vZ6Q}>oXH8{!CBc7G zIQkjX4c)bGy4oJ?+R^7Ex^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#Ls#UM7i-_Od{jNNSu3~b^v1_&bs-y%2~ODd zc%p2YUs zQBl=M_MF4v#J0C5cWe{SmYK99>AK346b3gXu|i`~7+Mm+sYlx{~L!%J%JI(=Nq0Y`EQ;?XP_%Sp400 zF)crl9)GEqTl}e;(s|S73Nfba9{djxQ&l?&$3)IIVwco@L?W zC#wXU*-o>jIo)||dH4gPo57@>(CL#e?|QW*FfQv?#*xST8BeCoU7ss=x^#JTto_v_ zv2Q$=SIvC8nA=#PAk2=d_h5HIE0^UtpN5+KnybwxP2R;;v2f;*Nk>IyGO;JA+~I0C zx@`hy&ZJWXjt5rcuFYdzbZgCYLoeM6T9QobVrTRy)=!Pjh&|E}+Gu%zbsBHUp>Otv z>rcm=w6YB9R2Ds$)OC0g-x-$MrzTG-5Dk^9D|x?q>#qAxrf$FNe@1D&wD0bvrQb8t M^FLR;zyJRx0P<4+d;kCd literal 0 HcmV?d00001 diff --git a/tests/data/dior/Images/trainval/000002.jpg b/tests/data/dior/Images/trainval/000002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a2ee225920e65e7c5c2d7b4153e857178639508a GIT binary patch literal 1244 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% 
zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#!_-T++$QWjy2R=3Ebo!qU=tbFDLo5PdzCXa(LxvRx4-#xSWmC*g;$%PX+ z%Lc&l1gBRZoKVOjr`TWxGDpE*&{PbQu>ei~grFUy?Z+)}nm4da9`;MKH9_)0vwL(-mcEKc` znPFUen%EwlzPEU<*+&&6O_iygVomu5IaArhr`T7T=iALxo@2Otk5tXIPr4W9$*yU< zem*U?hmmzk;tautN|p_`JP+LFVfyvlcXHljBj3rbI*ht=_D`3*ZWn;*zsMS4h#qLjkn3}u3w&dEsZ%aM?tX|Q($aZ1TLc7yX*1b$l z6g;8B8KyF4TJq#3wM+G8U zZ|;aswPZ+Umh!8)nyVdk{mu9Cm-*or-yPM93<-Fx{o>c{)QgpOUoTtjGCS%+!RI?B z1-s6qcRj2**_l?cDX>H0j$}m2oN3P|Z&Nf|n5YxzmOIZ_Y4;ruhl!s!Bh3@L=6Ogr z%RDYrD6T$dwld!)=)8>KaY?(R1ue@S9qBk{mponCA<1ZOk$L)(<$+h4b$=h z#_>-J6i+|rmSDEoF73%OG1AE>a8lQv(red0z1_RzzV9Tp@+Vuabw8VRIpxx}?Oy*G KY}f4ne-i+;1P1v4 literal 0 HcmV?d00001 diff --git a/tests/data/dior/Images/trainval/000003.jpg b/tests/data/dior/Images/trainval/000003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bdef43ad0852a564e5b32010783d1fb00cb11a89 GIT binary patch literal 1243 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L!0&Lh54U-KJ8?EwC(c~p7YyP%&wlF5$SO3 zjm*Z-pN`8-wsV(!4~lgZUR2q!j{8u+X3prRK@TQK2)Vl)e3F`*#mN$`b>i}2)i*O5 z^BA~}s9n7ilHc*Z(vsbh_nhIZrqFH9-&~V6>(5h~rLWmOXHvCx>irvl66$&^3CYa_uuVZyY9=Nd^a+@h=L`Jw0O zp7pbolo>;&{AUQOaBz4cRKTyc`lN}#iPbqD-LpkJcmj{jE+{yDCVB4N^7N++CyXam zs%?_0pQ|!ma81$HLo@!Ji}YHlocDL_1n;Tm>U;Y4PEfw0vuw-8-nET?Pgh1KwW_7{ zM*Mx6!|+Ptu_Vjpl?;VmMm*X4`<84hgQy|y&!<^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#gU#)irfECoUAm=RBw{$}#kwynb7pQ5S;w`R z!Fb-@%Em>4(>`(es5hLs>~a3tdkc=;iFeqOw{F;P@WSz|c!bWO2$R+7(;}01Oy&v6 z$(>YlXEpC)l>&nuI#OHKJrR$Zve7*=wNPiB;`Y@1uL(M{{T`Sb=LKG2@H5=Nb9Lpm z#~N=_%l|VlM}L36^^4o8IUze2yj@hAbG%M#1 zF^>963LzB}7g|SlCn)t*N;2-0cE2Za=Gw(m>v^h5^0N!xem=wzI{SI*xpzJL%U9;! 
zc&;$_rsF;Tx#ro~`LnOD{k>W9o|V_%ulLS>4D-9bef^YQJGaeybNrdN=#}NCIy?d- z44XI~a;j$uH#qQADobZ;UjA~|sX{+TWpm4BukLbziB4}*f6ibk-Yu^$@nk|>?n%SF zMz@_=x^->{bY6OWqWHAHswEGa=j7e;ymh4acG>BQ4L9PBm#A!w-sE-wDv+vH^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L()x^S=G1h>7+j^xov-@Sb6H5a}um-9v62U zYPd5YLRf9SjDTOa=_O|YUJjvSLh;6@ZH~)X)qOq2^iH)_qEUv){Yq*{ljOOZg?5*Y z?%5++M-desynk z_F9u$-}cX&8|mv+GUvAcvbnXwuXpd-TwA4V<#A@UbgXa3!U>zjUxwl8& zCAU4F!1s89N{4&ua*s`iYm~P#J&n=Y?^Kl+6vfh>SmzQVOs&3R- zZF_+&P-;S}c-FQrVF?~D539>0nE$q_oTkZjF8oBmB!h$#Pd#{?nZjA8vwS~iE%mbJ z<>}JYr{&+{Zrgv}I&bF6V_YY1?#|3L{aLzirCsmsKr!izce-oV)!mfQ6VndS-e#V_ z)Vn5J-Fp9yN@bnBEZ@V9#Tw~e7d)PA*xwJnN+^CpTy5zF#)$*4=vd)%6vAo;yC@JH7YIwq4)KwmD*?Ua&d97k1=gR zMxHt!Ckbwfb>`_kRVHR(af|(oUBdMA-*(Revn!M*Kb@3uw`nS)!(qE}p*_Mq%)3?o z95z&$R(7#_N6TxYww+DSl%IL(9Sb!U?pgKW==Yi()^6{otohVAIVdq<)7`TL7rQsF zaFGnHOWxU6xVvdboZ0;8mMYxSbeDaJ{uy;TWXr|q+`GNW8Sb-ecfFWgeSgvWr?2|| G-vj^)oBZMc literal 0 HcmV?d00001 diff --git a/tests/data/dior/Images_test.zip b/tests/data/dior/Images_test.zip new file mode 100644 index 0000000000000000000000000000000000000000..6a8fd77df479a72123e6c2108e120e4860ae5c4b GIT binary patch literal 2511 zcmeH}`BM|t7ROm5M#L(D-~g6Vg6vR)u(xbsEkVK-kUarpl{FG7RM}(^WG_n+#g-_m z1|i7O0D=KQ*4AJ!MA-u*v`7dL2@kPv`gC6ZfH(8boqNu?Gv9OWZ=d-(*$WDt=KBz= zEIZG`=C34rP;DdpgORbCamd&>O{c$sdX3{p;v;rDRcXR z0Des{6byhs0079Bi0zchp7QIdw{Br>m84=Yn4+T89P6~Q%*sk@XPcjVE#2&V9Np|7 z5Hl;Nx0$Uq90GZI2}bT7>+7@hp-sH(9e|8Z^@$)v6qJ;dPJr`;G@H^2s&lJP4!+sv z2bhTggaN`q5JdpU3=(AaY~NaH{yGbR6cPlS_d^5(`6C0D-ptPeRHjtAyZZ))hx>Ym zpFi&vC~gpbuJcfY08~C}NH}@~L{OE;meO^+G;nwA@t=1>#`;gFJhQd>zUorI?5RTM z+R&A=6`r7)phd}2w-JrTNR}Mdr$_#-|J#!V7YCIeTxLj@!}}F~K`uy?xV}(th+xWM zyo*54a>Iz--125gbss|9s|^WN_XuhyaJF6P?wK0P1)IHa`SPd?|1Mvj-bmaw=)d6dgV6{N~siF;{ew$$-+2Krjw|v=$5Z=FWOo z^R3*QCDUuEmsMt_-%V+N_2tHNV-}P>14~?tmPTlfO_Ovhqhl%ORKS62C-2`Lwhf}G 
zJyz>iHlwK2sL#=KInuyvu3S~Ii5oV!sGZSm^`Nk>F&Gp4_d0v@I2k>q^bNGq$IqA1txFm!#Q5j^80I;lnhz7p_c1Z` zVi>!IOq&#Vi#Gw}X5s8xnDmsIhri35iBe!J#{;$ZeLO;G#I~oLZ5Uku(9zl2(!2+r zzN(S)gyBtbIbSFRQbxBY-h-d$h><4Isb}1W%s~s8P1Be?W7O*EdPz0#+a{a-MM?&< zoy249_B(289HB>|&`~Il7^;10BR<)5s4BoxhrB_|#05!s^4 zF^}+dR7_w@9zW(WH{9UttcccNG^5O9>vRrP?YO6(8aJtZWF;8->o(Mu9$FgVO%o?xTO_{L$uu|C zW9Ciqu;ohh-9bdgvjcm>YZQ@^ce&ljR8PjEUUW`rIsEZ>pew8Ck{1 zeu(!P%_i0oLZ8<)*X6Ki?Q9e^$~%dso;^*3l)jGKZ(y{RY1=p~^rRJt!mkC!9Jj5a>`d|XtMz|+N)ql6!(1n=JCex9%GIl>7uHGTQ?lBs6Da5# zz#j8eJw-pA z)zz-Rdy;ffmtl&dUbL?K$Wez-?=wj^2QOy=Q;$%9dE(?L!M3N)>sS+GgB0ZB=r#3@ zb6u6Nr3Xf`lNSSv!A{%Fe(^WXcPvxmdVt7i9P@gf9T68TW3apV%mGZQ7S4CAFm96x z0Jd)Lw3qpepmt_C_Iv1vVsfpCFZPF5&V*I;%Hot}OVUhsL-e+kk=5F`dQRr?_Sl)L zn~VmahTf_0uB9M09zOHqLYua=C$5E?&>Y4a4mTU|?8W8kv1hW%x9=Xr^&8TY=wgW8zb0XOfJ zJVAp+NW@3&l^_qPALR?k`IoPE}M_p|n|XYcoYZ&-dM zc4XHd7bn@m>C@mpuWvto&HX(f-XWTy!5+|na1TGt8~=&k_bGZHHUn)ewrkhK$LG}l zh`#XY(zqK45yj8#d?m4469588n3_sRn11i|K0dE2C*haWRCH?QFK9Fxf%qlGEHUdr zUY^;_>)CEs?JZnw>@7@9FJA+JFPob=nwq|Ju_biA>h8YkN*^A7(Pv7Y=-%hGPfkfm z=^!9oOtUKS>9fbr4)xFM?3TDJB_S>$?qjMTp?KNU=W_3k*@^iCj+eLCR}#wZUYA1? z>)-u$=T*-^m0t5})+iUxH_{n5OYgEy{f9gM9 z0$o0ty`Zy4F5U4-7bc$v&8kE%3DgaX=waCYx*!O) zxFEF%6n(k?K2)D5&Q%L0x8;Vj#vl$D=yyW-maIz_#eR~dgGMK!o+;j6M0<12Epwo! 
zfK%m4zjQz%6037R#Ne}ddKQmUr|X~1W&EN|_;<*Z&X}mD4P3Pn7n^#3W~i!WJpH|XWWbp$RPi`pxfwXZdUd-9H;@i$vO1ug_x)0o@5YN;FA36B zil%5aq|vV%j?2h_mr5qZEmyyOJI~5I5JbjSY(TqI<*c)OC?%dCrQV!rJM662`o$|h zMBY4Pd2I;i{5tA5sNB*n=Vr!9d((eQJcoUsbX z`0UYXc_OzN1?+cyuk+%HuDtL@tJDL5P5Mhk6N)+-tY@neTP4MF%>QPYAgjK*DWBh| zy)hq@Bun}*)c~8u<``FnHdGClqHXmYOtm>8W%jZ6QKS$1pAA|0l8Q<=s0d*~f(<2U zy#0jbec)iW3%^JFLEhs>0q}}({Vo0RyZvr;JCe+@tF^j<0n#@4mHiffSg1@_e@(%a zrUy*=5rd!P^?0XI3eC-mea(}yjf_s*D4EdczYt`3i*<(g!Vb3RP=Ru(eUdy9&#Wt| zZOfQ3Q8}|E$x7l(0>J+1;q0XlY4Ew`@2ky&SsyZ?!l|w7)aXPhQ}ElPEAV2a<2BVn zWRI>e+XT0D<*xm&$u*$BP|tn3#0+CQEL(;Tr;6ZJu^j;>S~O2wcg>0>7f#!%AjN6F z+p*sD|E+UMt{}whqj4C2WgV@5tn;_Dj%x5=N(0x)xD-Pz&vXuro|~(0eUeHLQ}TlY zg`GU_Gj*6s{30Vx53fa(+Lg$Wq8M2KVzEy+z?Q&@9h0dQ$;J0mb;R+BOH5jl^D0 zr|ty+^hk#{`K2&M38~e)=EKM>;U$u))CFXe4Qiy%hl0)P#VDpfynt~~^6)0t5#Ts5;YwvJD84QHQzGc=+P@q#`6O>wK9Jct+( zULRe{y4tCpZcIJjiv%)rP14`tRWAUQTAM*|LKXrPu`W1*8t>30*9NX3?DmlSiXBH# z9AG2xL&9sK`%Tc&dVf{~ep32G7GHI)@Amd%r4eQ&Inxl_3N#qszMC-tzzml8w1bEI z%Wr(^Ll}cs8|G&YuAx9b661(tOoSrWv+QSmUSDCsIh~vd#4S#s?ispXf(~T3WDnsh z)nE#@-Mintug=zvLI+wsCTl2vbRs zwPG0nlev#98xb1feT?YRof>i=lq7SlWI!{~{UXEpB;y<@yP0 z&UwZ=NHVjcRH+NDsFqBx1!JLEX<=DHou=n$#SVAkqN_0Peh|EWWEMgcFk;X+ysL=p z)&lk|X!8NnG2vmpV~k;(Zc|e`b^bF2jnmlx&(#YB`-DcXk zjj~$sD5Gd#)TX8=<`-0n-y7RiRU(h=qe9klX*VlnPAbt;cwUr?t|nL{?PZ^VLQsB~whu8K-uV_wZl+UDq}M+GB- zCUTw)Io#xujAC%VUEgYgkVi*yW4No@Cy(qbX!uNgOIV!0dxPHEz+K^?V}VN80QgPm z8zq&v203E051GuWJitrp6e*O$k25}Dz8OF(uqgJiTCXP&; z4KmX{=S`|#L^PexRks@5A3BJJl-G;W9=z-Zp`J36!8~ALp116Al_L&V9p@a_;ZWD! 
zx6Os%;d@B$^t4b^oh3$t{lu%jsSg<2zt*WrZ0$(rV$LIj=H3^JKPtwCSYDv{WeL1; zPAq#J(J%1@EjO%S860}>bVP>J211?EiY2z9f}ynVsvXHjU%^ht{x9TXkwS03`cXWz zzmkv6Kl1rI$|s)~&42NMY>W-BHJ%h~>?d!qReT)CxLN9wHC;R94n&+;d+MS{OK7e?R%mP6lUVE?J zNYQYS8T+%q^M#M*?Ea^D(V7M*6H#VIq8rTFW?~?7^)OBgpCl#VnEx3aaOLvt}ul3%|;Y4C>6`l5&{ zsH8`Gjpaw&bJ$7F9LvP#Hbm{TJ`YKiKOXRUZ2iM4UMP9f(PH57S=()$L_2+b;31V6 z2j|$~t25m!R62)xP8OIz@Hu|uj0plL42asaO*-6*P#M!RA0yl4XfzkyS;;@SL%J4S zaf>9vh2Im7hVYHH;7ZDFuN@n`Z*wPmJBn2c5Kd#m(Q@0yM2|>?cRy!%^DL6CXcl0e zZ#cL2cyB?nWkrC%GD_i$5zf`P#e`4HqD9dV9JBYz6E>9Z4d9=4`DDoZ2}D^ z#@s!K1kI<~B{yMICfihTTSX-Wrf%>~dg*pw84x zv|h4?ENtlwCf4B%%MH_X0{#2&A;wwP?pPZ06q7?So!7)|q8=@_Cn^{A$;SliBjYJ6 zYdt%XdmGuV&Zqdn%j0FvNT7O-8#vl%GG@8z@$v|1;J@9IiMW2zl@T~KUu%;;&7i$H&86)#TU zbQ`YrI|Eh$k+YHZz$=S468rD+p-=k_6?uvC^ERa|b-58W771NaG9~IOn1^Q8)Fn)g z)9-aQLGZiV6frb4?%+@=Z%@^Q_a_5pT3Ul?w^N%+7Isc%7j$$?Z}#|)p`M@*Ff>Yu zEv~0Uc#TyCj000tW5o3D7G^v<=-WZYcj^CFq#Lwyx#FvD0JkC){Y0emh6IfoiZy2R6GB2_Y<*Xe4;x2)XC#}?GK zc-J{Zv~*rAZYnW>TsX=_QnDho^EHHpv#td+k{aCrAZDP6H~uNc$$cziFEL~-Ac|oy z49Bg8_9UGSB3RyPG==dahjW4b9&c&~os~kaMV+ysH&(t4bbL0Q@w^yBVQECeU$3u; zgwzO#Z^a|Hb}U68OsS%83IetSnFf@{5Vs?2XP40&Y?uc-2T(C`82!EiKg545DGE!O zdY`C*q_Op{;y5pjH#=Hj`qMyL)WRr&yh*Tkrfb{MtqYDHN zh>T0=8!uKWciM2V??iZ5Tzt)c=Ud3UD(K~p%9;Et?dblaoxh`Y&Mp@Xl(U?Lp~KkB zsYVmqav+~{X}`U0dWxZx?o`gVXUp~6OgpP+>lt;sG#|Xu%8#5I(IYIBv)W~EM;AM( zzE*xYAE(ezXudW@yn?ed6RsXoUyA{5^-r^o^vecq*#XG&r2?pihs*fz>t`GHl)7dv zBEo{ZWC-(V=WQ3oI$E0y;1&H9Fwr70b)B!06GDrpPC3m|a1s8M6*n8#^y{1q#+z=> zCRV^P<8FrGPYi`qG&ybt4Fj5i11~>+X%yB%D0-(z70nc}ow)JKNce%o`k!JLjHp`9*TE zERbY1TR!60j;%vRQ+aElP}ayAH>ini0z(5r7PKl6fCvaPcUz6#G+ZeR|1CST*O0Ee z!U`BzN`^kT1JePyv%rTd1Mu-?&udApNUjILGc~m8fGoK&D6URxQcd6mxf4<%HBy?s zcvUn3QooVb(4vo99)@yEgnS#tcUFfuCUf4VK_(v_2>WsYctSHyO{ZUwi?{> zN_JM3Nmq%NI$dg|ikrX}EHwN?BNagg$UuJD{Zj{o<(q-)aH5oNlhAfNhJ`xwYj*WA zhk7rCp6AefJ!xElyQ(@xua_hFZHHPO)(&jX$6Raoo+DrVc{l9ejj?Zcu`mBI`tRP@|IPm2-LZeq-nFYRLF0eS{@g43 zyo!HwzN~kD<`jNZ;+L!V#Q9YH{*v=$;p#I-ZvUTfK9#Y) None: + """Create random RGB image.""" + img = 
np.random.randint(0, 255, (SIZE, SIZE, 3), dtype=np.uint8) + Image.fromarray(img).save(path) + + +def create_annotation(path: str, image_name: str) -> None: + """Create PASCAL VOC annotation file.""" + root = ET.Element('annotation') + + ET.SubElement(root, 'filename').text = image_name + + size = ET.SubElement(root, 'size') + ET.SubElement(size, 'width').text = str(SIZE) + ET.SubElement(size, 'height').text = str(SIZE) + ET.SubElement(size, 'depth').text = '3' + + # Add 1-3 random objects + for _ in range(np.random.randint(1, 4)): + obj = ET.SubElement(root, 'object') + ET.SubElement(obj, 'name').text = np.random.choice(CLASSES) + + # Create random box coordinates + x1 = np.random.randint(0, SIZE // 2) + y1 = np.random.randint(0, SIZE // 2) + x2 = np.random.randint(x1 + SIZE // 4, SIZE) + y2 = np.random.randint(y1 + SIZE // 4, SIZE) + + bbox = ET.SubElement(obj, 'bndbox') + ET.SubElement(bbox, 'xmin').text = str(x1) + ET.SubElement(bbox, 'ymin').text = str(y1) + ET.SubElement(bbox, 'xmax').text = str(x2) + ET.SubElement(bbox, 'ymax').text = str(y2) + + tree = ET.ElementTree(root) + tree.write(path) + + +def create_dataset(): + """Create dummy DIOR dataset.""" + root = os.getcwd() + + img_dir = os.path.join(root, 'Images') + ann_dir = os.path.join(root, 'Annotations') + + if os.path.exists(img_dir): + shutil.rmtree(img_dir) + if os.path.exists(ann_dir): + shutil.rmtree(ann_dir) + + # Create directories + os.makedirs(img_dir, exist_ok=True) + os.makedirs(ann_dir, exist_ok=True) + + for split in ['trainval', 'test']: + os.makedirs(os.path.join(img_dir, split), exist_ok=True) + if split == 'trainval': + os.makedirs(os.path.join(ann_dir, split), exist_ok=True) + + samples = [] + + # Create trainval data + for idx in range(6): + img_name = f'{idx:06d}.jpg' + ann_name = f'{idx:06d}.xml' + + # Create files + create_image(os.path.join(root, 'Images', 'trainval', img_name)) + create_annotation( + os.path.join(root, 'Annotations', 'trainval', ann_name), img_name + ) + + # 
Add to samples + split = 'train' if idx < 4 else 'val' + samples.append( + { + 'image_path': os.path.join('Images', 'trainval', img_name), + 'label_path': os.path.join('Annotations', 'trainval', ann_name), + 'split': split, + } + ) + + # Create test data (2 samples) + for idx in range(2): + img_name = f'{idx:06d}.jpg' + create_image(os.path.join(root, 'Images', 'test', img_name)) + samples.append( + { + 'image_path': os.path.join('Images', 'test', img_name), + 'label_path': None, # No annotations for test + 'split': 'test', + } + ) + + df = pd.DataFrame(samples) + df.to_parquet(os.path.join('sample_df.parquet')) + + for dirname in ['Images', 'Annotations']: + archive_name = f'{dirname}_trainval.zip' + archive_path = os.path.join(root, archive_name) + + shutil.make_archive( + archive_path.split('.')[0], + 'zip', + os.path.join(root, dirname, '..'), + os.path.join(dirname, 'trainval'), + ) + + with open(archive_path, 'rb') as archive_file: + md5 = hashlib.md5(archive_file.read()).hexdigest() + print(f'{archive_name}: {md5}') + + archive_name = 'Images_test.zip' + archive_path = os.path.join(root, archive_name) + + shutil.make_archive( + archive_path.split('.')[0], + 'zip', + os.path.join(root, 'Images', '..'), + os.path.join('Images', 'test'), + ) + + with open(archive_path, 'rb') as archive_file: + md5 = hashlib.md5(archive_file.read()).hexdigest() + print(f'{archive_name}: {md5}') + + +if __name__ == '__main__': + create_dataset() diff --git a/tests/data/dior/sample_df.parquet b/tests/data/dior/sample_df.parquet new file mode 100644 index 0000000000000000000000000000000000000000..1f99eba0378bf72ebd94a32030b8628059f52392 GIT binary patch literal 3186 zcmc&%J8#=o6uy!yHE}!$MJPyylNf9xWN>6r4<`yz6qFLlRAt9eVo_3C7<`$cOj3+U zy%KooP-N&2=+L1BTA*Wxju|=>os0a24nfhOL#IqVht$(HTA(o+!o0lqoO{3X-SdLH z$5t6?iTX7_edViAVT!y<$SdF8Dj|e31XZy&ovxsneL>$$27N1uwxx$^%KSAM*p*v^$7s;eZ-;C6~o5o&b^?FlkjA5cQi85 z0m>$dlSz~VpG!CVFou1B!1G^T$N%)NQ7f@5358?PWeuqwSc21{{E=|v%V6Y{KmM~X 
ze&N4bi*0~Y2<0a*gyBO7tiKIVpMjHBTqnKs$=5*4nr6ol9Ia!X^wiMM`|p9LK8GQN zLmGz+4B>17l&Ac%bmQ9DXy^m?Nc0~zBoTA!rND>iC4cqEk90`Crx`ayDv;4jn zdVhd?7yoT}cBi0i>6+sP{cB(3dtdyE<@k$>81=dHXYqYZ#%x{5sS!Amm=Z#g~;pBWkO;b(U2vWvS3HQ z2|lAWQ&xrtQb+F@rcK{j|M(ew1TvV>UBOfpdSjh71w#RW+0%7c*^bb4V3~C19Sf$n z-Km08=J-=scz6dcxH5)aL^;$HT|QWpd9)6kk%i{yo~cP4*<;W%tkG1_5!;I7ATvX8 z1X*wd*#3-GOxUiODxTecT)cUkKAt19$bz2rGsIb7b3*q;QMi!=0IvsBB}Q@ed}yXJ`Ij`l#y zi^{0n75MBC*Y2pCS#F7jtaI8x<&11g1}^>1tLgAtzRo`K$%_N6kkZ+6IoyK&j=DNSv6{d$t!AorC;^_};Tr;9vzx`- z^j_d^r@_m5J)LjKB~^3%b>8{9*RcMfmm)~9!aI4HICb$0lC?z?;gtA`e3YQL_^2eP4bix>-p?_ZiuP! zlP15L^zxX|r-eFw?TBj2t_|Wg!Mmp_EG5h2u0;Kc9ne`P?4{qkd<(7>mMg>pJCaBya3^F z4-O4kY(CI88f<>R?%-xHHtA0jJhwBQ#9OEXI0yR4(GN&;-4eP_dWv&XvaFkK6I~xS X$#t)N#_qzeDDRIr7YW&e|2F;xL{p2J literal 0 HcmV?d00001 diff --git a/tests/datasets/test_dior.py b/tests/datasets/test_dior.py new file mode 100644 index 00000000000..5463da333e1 --- /dev/null +++ b/tests/datasets/test_dior.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import DatasetNotFoundError, DIOR + + +class TestDIOR: + @pytest.fixture(params=['train', 'val', 'test']) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> DIOR: + url = os.path.join('tests', 'data', 'dior', '{}') + monkeypatch.setattr(DIOR, 'url', url) + + files = { + 'trainval': { + 'images': { + 'filename': 'Images_trainval.zip', + 'md5': '17b9a13f7f9e30bc04f9d70b4bb0a47b', + }, + 'labels': { + 'filename': 'Annotations_trainval.zip', + 'md5': '887a590a2872be81f00f21f502a7cb56', + }, + }, + 'test': { + 'images': { + 'filename': 'Images_test.zip', + 'md5': 'e14666a09788bfb0d5ad39a82f7da946', + } + }, + } + monkeypatch.setattr(DIOR, 'files', files) + root = tmp_path + split = request.param + transforms = nn.Identity() + return DIOR( + root=root, split=split, transforms=transforms, download=True, checksum=True + ) + + def test_already_downloaded(self, dataset: DIOR) -> None: + DIOR(root=dataset.root, download=True) + + def test_not_yet_extracted(self, tmp_path: Path) -> None: + files = [ + 'Images_trainval.zip', + 'Annotations_trainval.zip', + 'Images_test.zip', + 'sample_df.parquet', + ] + for path in files: + shutil.copyfile( + os.path.join('tests', 'data', 'dior', path), + os.path.join(str(tmp_path), path), + ) + + DIOR(root=tmp_path) + + def test_getitem(self, dataset: DIOR) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert x['image'].shape[0] == 3 + assert x['image'].ndim == 3 + assert isinstance(x['image'], torch.Tensor) + if dataset.split != 'test': + assert isinstance(x['labels'], torch.Tensor) + assert isinstance(x['boxes'], torch.Tensor) + + def test_len(self, dataset: DIOR) -> None: + if dataset.split == 'train': + assert len(dataset) == 4 + else: + assert 
len(dataset) == 2 + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, 'Images_trainval.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): + DIOR(root=tmp_path, checksum=True) + + def test_not_found(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + DIOR(tmp_path) + + def test_plot(self, dataset: DIOR) -> None: + if dataset.split != 'test': + x = dataset[0].copy() + dataset.plot(x, suptitle='Test') + plt.close() + + def test_plot_prediction(self, dataset: DIOR) -> None: + if dataset.split != 'test': + x = dataset[0].copy() + x['prediction_boxes'] = x['boxes'].clone() + dataset.plot(x, suptitle='Prediction') + plt.close() diff --git a/torchgeo/datasets/dior.py b/torchgeo/datasets/dior.py new file mode 100644 index 00000000000..07af4e8e8cc --- /dev/null +++ b/torchgeo/datasets/dior.py @@ -0,0 +1,393 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""DIOR dataset.""" + +import glob +import os +from collections.abc import Callable +from typing import Any +from xml.etree import ElementTree + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor +import pandas as pd + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import ( + Path, + check_integrity, + download_and_extract_archive, + extract_archive, + download_url, +) + + +def parse_pascal_voc(path: Path) -> dict[str, Any]: + """Read a PASCAL VOC annotation file. 
+ + Args: + path: path to xml file + + Returns: + dict of image filename, bounding box coords, and class labels + """ + et = ElementTree.parse(path) + element = et.getroot() + filename = element.find('filename').text # type: ignore[union-attr] + labels, bboxes = [], [] + + for obj in element.findall('object'): + bndbox = obj.find('bndbox') + bbox = [ + int(bndbox.find('xmin').text), # type: ignore[union-attr, arg-type] + int(bndbox.find('ymin').text), # type: ignore[union-attr, arg-type] + int(bndbox.find('xmax').text), # type: ignore[union-attr, arg-type] + int(bndbox.find('ymax').text), # type: ignore[union-attr, arg-type] + ] + label = obj.find('name').text + bboxes.append(bbox) + labels.append(label) + + return dict(filename=filename, bboxes=bboxes, labels=labels) + + +class DIOR(NonGeoDataset): + """DIOR dataset. + + `DIOR `_ dataset contains horizontal bounding box + annotations of Google Earth Aerial RGB imagery. The test split does not contain bounding + box annotations and labels + + Dataset features: + + * 20 classes + * 192,472 manually annotated bounding box instances + + Dataset format: + + * Images are three channel .jpg files. + * Annotations are in `Pascal VOC XML format + `_ + + + Classes: + + 0. Airplane + 1. Airport + 2. Baseball Field + 3. Basketball Court + 4. Bridge + 5. Chimney + 6. Dam + 7. Expressway Service Area + 8. Expressway Toll Station + 9. Golf Field + 10. Ground Track Field + 11. Harbor + 12. Overpass + 13. Ship + 14. Stadium + 15. Storage Tank + 16. Tennis Court + 17. Train Station + 18. Vehicle + 19. Windmill + + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/1909.00133 + + + .. 
versionadded:: 0.7 + """ + + url = 'https://huggingface.co/datasets/torchgeo/dior/resolve/main/{}' + + files = { + 'trainval': { + 'images': { + 'filename': 'Images_trainval.zip', + 'md5': '070e9314120403e5c965d12fe5321cb0', + }, + 'labels': { + 'filename': 'Annotations_trainval.zip', + 'md5': '90e045de37255c5919bbecf659b72c1a', + }, + }, + 'test': { + 'images': { + 'filename': 'Images_test.zip', + 'md5': '97f3cbc86de0867624a6a34190c694ae', + } + }, + } + + valid_splits = ('train', 'val', 'test') + + classes = ( + 'airplane', + 'airport', + 'baseballfield', + 'basketballcourt', + 'bridge', + 'chimney', + 'dam', + 'expresswayservicearea', + 'expresswaytollstation', + 'golffield', + 'groundtrackfield', + 'harbor', + 'overpass', + 'ship', + 'stadium', + 'storagetank', + 'tenniscourt', + 'trainstation', + 'vehicle', + 'windmill', + ) + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new DIOR dataset instance. + + Args: + root: root directory where dataset can be found + split: split of the dataset to use, one of 'train', 'val', 'test' + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. + AssertionError: If *split* argumnet is invalid. + """ + self.root = root + self.transforms = transforms + self.checksum = checksum + self.download = download + + assert split in self.valid_splits, f'Split must be one of {self.valid_splits}.' 
+ self.split = split + + self._verify() + + self.sample_df = pd.read_parquet(os.path.join(self.root, 'sample_df.parquet')) + + self.sample_df = self.sample_df[ + self.sample_df['split'] == self.split + ].reset_index(drop=True) + + self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.sample_df) + + def __getitem__(self, idx: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + idx: index to return + + Returns: + data and label at that index + """ + sample = self.sample_df.iloc[idx] + + image = self._load_image(os.path.join(self.root, sample['image_path'])) + + if self.split != 'test': + boxes, labels = self._load_target( + os.path.join(self.root, sample['label_path']) + ) + + sample = {'image': image, 'boxes': boxes, 'labels': labels} + else: + sample = {'image': image} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _load_image(self, path: Path) -> Tensor: + """Load a single image. + + Args: + path: path to the image + + Returns: + the image + """ + with Image.open(path) as img: + array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB')) + tensor: Tensor = torch.from_numpy(array) + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + return tensor + + def _load_target(self, path: Path) -> tuple[Tensor, Tensor]: + """Load the target bounding boxes and labels for a single image.
+ + Args: + path: path to the annotation file + + Returns: + the target bounding boxes and labels + """ + parsed = parse_pascal_voc(path) + boxes = torch.tensor(parsed['bboxes'], dtype=torch.float32) + labels = torch.tensor( + [self.class_to_idx[label] for label in parsed['labels']] + ).long() + return boxes, labels + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + df_path = os.path.join(self.root, 'sample_df.parquet') + exists = [] + if os.path.exists(df_path): + exists.append(True) + df = pd.read_parquet(df_path) + df = df[df['split'] == self.split].reset_index(drop=True) + for idx, row in df.iterrows(): + if os.path.exists(os.path.join(self.root, row['image_path'])): + exists.append(True) + else: + exists.append(False) + else: + exists.append(False) + + if all(exists): + return + + exists = [] + if self.split in ['train', 'val']: + files = self.files['trainval'] + else: + files = self.files['test'] + + for key in files: + filename = files[key]['filename'] + md5 = files[key]['md5'] + path = os.path.join(self.root, filename) + if os.path.exists(path): + if self.checksum and not check_integrity(path, md5): + raise RuntimeError('Dataset found, but corrupted.') + extract_archive(path) + exists.append(True) + else: + exists.append(False) + + if all(exists): + return + + if not self.download: + raise DatasetNotFoundError(self) + + self._download() + + def _download(self) -> None: + """Download the dataset and extract it.""" + if self.split in ['train', 'val']: + files = self.files['trainval'] + else: + files = self.files['test'] + + for key in files: + filename = files[key]['filename'] + md5 = files[key]['md5'] + download_and_extract_archive( + self.url.format(filename), + self.root, + filename=filename, + md5=md5 if self.checksum else None, + ) + + # download the sample_df.parquet file + download_url( + self.url.format('sample_df.parquet'), + self.root, + filename='sample_df.parquet', + ) + + def plot( + self, + sample: dict[str, Tensor],
show_titles: bool = True, + suptitle: str | None = None, + box_alpha: float = 0.7, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + box_alpha: alpha value for boxes + + Returns: + a matplotlib Figure with the rendered sample + """ + image = sample['image'].permute((1, 2, 0)).numpy() + boxes = sample['boxes'].numpy() + labels = sample['labels'].numpy() + + fig, axs = plt.subplots(ncols=1, figsize=(10, 10)) + + axs.imshow(image) + axs.axis('off') + + cm = plt.get_cmap('gist_rainbow') + + for box, label_idx in zip(boxes, labels): + color = cm(label_idx / len(self.classes)) + label = self.classes[label_idx] + + # Horizontal box: [xmin, ymin, xmax, ymax] + x1, y1, x2, y2 = box + rect = patches.Rectangle( + (x1, y1), + x2 - x1, + y2 - y1, + linewidth=2, + alpha=box_alpha, + linestyle='solid', + edgecolor=color, + facecolor='none', + ) + axs.add_patch(rect) + # Add label above box + axs.text( + x1, + y1 - 5, + label, + color='white', + fontsize=8, + bbox=dict(facecolor=color, alpha=box_alpha), + ) + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig From 08cd8395854709e58ba870f082976c8acc908835 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Mon, 10 Feb 2025 17:37:09 +0000 Subject: [PATCH 2/8] init --- torchgeo/datasets/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 8177120c2a7..fa976e9cc03 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -37,6 +37,7 @@ from .deepglobelandcover import DeepGlobeLandCover from .dfc2022 import DFC2022 from .digital_typhoon import DigitalTyphoon +from .dior import DIOR from .eddmaps import EDDMapS from .enviroatlas import EnviroAtlas from .errors import DatasetNotFoundError, DependencyNotFoundError, RGBBandsMissingError @@ 
-156,6 +157,7 @@ 'BRIGHTDFC2025', 'CDL', 'COWC', + 'DIOR', 'DFC2022', 'ETCI2021', 'EUDEM', From 4409f164f05adb48ee58ef7104b8e58e8e65af63 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Mon, 10 Feb 2025 17:38:45 +0000 Subject: [PATCH 3/8] naming convention --- tests/datasets/test_dior.py | 2 +- torchgeo/datasets/dior.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/datasets/test_dior.py b/tests/datasets/test_dior.py index 5463da333e1..3321490e542 100644 --- a/tests/datasets/test_dior.py +++ b/tests/datasets/test_dior.py @@ -74,7 +74,7 @@ def test_getitem(self, dataset: DIOR) -> None: assert x['image'].ndim == 3 assert isinstance(x['image'], torch.Tensor) if dataset.split != 'test': - assert isinstance(x['labels'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) assert isinstance(x['boxes'], torch.Tensor) def test_len(self, dataset: DIOR) -> None: diff --git a/torchgeo/datasets/dior.py b/torchgeo/datasets/dior.py index 07af4e8e8cc..e45ada374fb 100644 --- a/torchgeo/datasets/dior.py +++ b/torchgeo/datasets/dior.py @@ -221,7 +221,7 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: os.path.join(self.root, sample['label_path']) ) - sample = {'image': image, 'boxes': boxes, 'labels': labels} + sample = {'image': image, 'bbox_xyxy': boxes, 'label': labels} else: sample = {'image': image} @@ -351,7 +351,7 @@ def plot( """ image = sample['image'].permute((1, 2, 0)).numpy() boxes = sample['boxes'].numpy() - labels = sample['labels'].numpy() + labels = sample['label'].numpy() fig, axs = plt.subplots(ncols=1, figsize=(10, 10)) From d80136d8022cc79a26b9aec3252d9a5b312eb12c Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Tue, 11 Feb 2025 09:29:51 +0100 Subject: [PATCH 4/8] tests --- tests/data/dior/data.py | 6 +++--- tests/datasets/test_dior.py | 11 ++--------- torchgeo/datasets/dior.py | 27 +++++++++++++-------------- 3 files changed, 18 insertions(+), 26 deletions(-) diff --git a/tests/data/dior/data.py 
b/tests/data/dior/data.py index a428a41c7e3..f89f61a3e09 100644 --- a/tests/data/dior/data.py +++ b/tests/data/dior/data.py @@ -7,10 +7,10 @@ import os import shutil import xml.etree.ElementTree as ET -import pandas as pd + import numpy as np +import pandas as pd from PIL import Image -from pathlib import Path # Constants SIZE = 32 # DIOR uses 800x800 but smaller for tests @@ -78,7 +78,7 @@ def create_annotation(path: str, image_name: str) -> None: tree.write(path) -def create_dataset(): +def create_dataset() -> None: """Create dummy DIOR dataset.""" root = os.getcwd() diff --git a/tests/datasets/test_dior.py b/tests/datasets/test_dior.py index 3321490e542..4dae4ea50a7 100644 --- a/tests/datasets/test_dior.py +++ b/tests/datasets/test_dior.py @@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -from torchgeo.datasets import DatasetNotFoundError, DIOR +from torchgeo.datasets import DIOR, DatasetNotFoundError class TestDIOR: @@ -75,7 +75,7 @@ def test_getitem(self, dataset: DIOR) -> None: assert isinstance(x['image'], torch.Tensor) if dataset.split != 'test': assert isinstance(x['label'], torch.Tensor) - assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) def test_len(self, dataset: DIOR) -> None: if dataset.split == 'train': @@ -98,10 +98,3 @@ def test_plot(self, dataset: DIOR) -> None: x = dataset[0].copy() dataset.plot(x, suptitle='Test') plt.close() - - def test_plot_prediction(self, dataset: DIOR) -> None: - if dataset.split != 'test': - x = dataset[0].copy() - x['prediction_boxes'] = x['boxes'].clone() - dataset.plot(x, suptitle='Prediction') - plt.close() diff --git a/torchgeo/datasets/dior.py b/torchgeo/datasets/dior.py index e45ada374fb..e0f8d32c89f 100644 --- a/torchgeo/datasets/dior.py +++ b/torchgeo/datasets/dior.py @@ -3,20 +3,19 @@ """DIOR dataset.""" -import glob import os from collections.abc import Callable -from typing import Any +from typing import Any, ClassVar from 
xml.etree import ElementTree import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np +import pandas as pd import torch from matplotlib.figure import Figure from PIL import Image from torch import Tensor -import pandas as pd from .errors import DatasetNotFoundError from .geo import NonGeoDataset @@ -24,8 +23,8 @@ Path, check_integrity, download_and_extract_archive, - extract_archive, download_url, + extract_archive, ) @@ -51,7 +50,7 @@ def parse_pascal_voc(path: Path) -> dict[str, Any]: int(bndbox.find('xmax').text), # type: ignore[union-attr, arg-type] int(bndbox.find('ymax').text), # type: ignore[union-attr, arg-type] ] - label = obj.find('name').text + label = obj.find('name').text # type: ignore[union-attr, arg-type] bboxes.append(bbox) labels.append(label) @@ -111,7 +110,7 @@ class DIOR(NonGeoDataset): url = 'https://huggingface.co/datasets/torchgeo/dior/resolve/main/{}' - files = { + files: ClassVar[dict[str, dict[str, dict[str, str]]]] = { 'trainval': { 'images': { 'filename': 'Images_trainval.zip', @@ -212,18 +211,18 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: Returns: data and label at that index """ - sample = self.sample_df.iloc[idx] + row = self.sample_df.iloc[idx] - image = self._load_image(os.path.join(self.root, sample['image_path'])) + image = self._load_image(os.path.join(self.root, row['image_path'])) + + sample: dict[str, Tensor] = {'image': image} if self.split != 'test': boxes, labels = self._load_target( - os.path.join(self.root, sample['label_path']) + os.path.join(self.root, row['label_path']) ) - - sample = {'image': image, 'bbox_xyxy': boxes, 'label': labels} - else: - sample = {'image': image} + sample['bbox_xyxy'] = boxes + sample['label'] = labels if self.transforms is not None: sample = self.transforms(sample) @@ -350,7 +349,7 @@ def plot( a matplotlib Figure with the rendered sample """ image = sample['image'].permute((1, 2, 0)).numpy() - boxes = sample['boxes'].numpy() + boxes = 
sample['bbox_xyxy'].numpy() labels = sample['label'].numpy() fig, axs = plt.subplots(ncols=1, figsize=(10, 10)) From 3863f0118dd2214929c3f6773c5d45248e6bc9d4 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Tue, 11 Feb 2025 09:36:54 +0100 Subject: [PATCH 5/8] tests --- torchgeo/datasets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 6b2dd80f4ec..fea07cf9931 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -158,8 +158,8 @@ 'BRIGHTDFC2025', 'CDL', 'COWC', - 'DIOR', 'DFC2022', + 'DIOR', 'ETCI2021', 'EUDEM', 'FAIR1M', From 7d1f3c3233e6d0239984a55a76365c9056c6a7bd Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Tue, 11 Feb 2025 09:51:34 +0100 Subject: [PATCH 6/8] import skip --- tests/datasets/test_dior.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/datasets/test_dior.py b/tests/datasets/test_dior.py index 4dae4ea50a7..6d48f77759c 100644 --- a/tests/datasets/test_dior.py +++ b/tests/datasets/test_dior.py @@ -14,6 +14,8 @@ from torchgeo.datasets import DIOR, DatasetNotFoundError +pytest.importorskip('pyarrow') + class TestDIOR: @pytest.fixture(params=['train', 'val', 'test']) From cb5e807994ddaeffb8c0d7802a869a7eb0f8bbc1 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Wed, 12 Feb 2025 08:34:08 +0100 Subject: [PATCH 7/8] myp --- torchgeo/datasets/dior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/dior.py b/torchgeo/datasets/dior.py index e0f8d32c89f..008dd5623d4 100644 --- a/torchgeo/datasets/dior.py +++ b/torchgeo/datasets/dior.py @@ -50,7 +50,7 @@ def parse_pascal_voc(path: Path) -> dict[str, Any]: int(bndbox.find('xmax').text), # type: ignore[union-attr, arg-type] int(bndbox.find('ymax').text), # type: ignore[union-attr, arg-type] ] - label = obj.find('name').text # type: ignore[union-attr, arg-type] + label = obj.find('name').text # type: ignore[union-attr] bboxes.append(bbox) 
labels.append(label) From 391e74753c5e3b4babef6a9b98b0e71a3f13a1bc Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Mon, 17 Feb 2025 08:43:40 +0100 Subject: [PATCH 8/8] requests --- torchgeo/datasets/dior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/dior.py b/torchgeo/datasets/dior.py index 008dd5623d4..0f6c412f2d6 100644 --- a/torchgeo/datasets/dior.py +++ b/torchgeo/datasets/dior.py @@ -108,7 +108,7 @@ class DIOR(NonGeoDataset): .. versionadded:: 0.7 """ - url = 'https://huggingface.co/datasets/torchgeo/dior/resolve/main/{}' + url = 'https://hf.co/datasets/torchgeo/dior/resolve/ec7be9567d2e08eb3d3401c15a52ee2145d0ef01/{}' files: ClassVar[dict[str, dict[str, dict[str, str]]]] = { 'trainval': {