diff --git a/.gitignore b/.gitignore index 9790f49..5c836c5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,8 @@ logs datasets .DS_Store -*.log +.vscode +logs_v2 +heatmap +184_checkpoints +.ipynb_checkpoints \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f288702 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/README.md b/README.md new file mode 100644 index 0000000..5cdf258 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# SAM +[Semantic Graph Representation Learning for Handwritten Mathematical Expression Recognition (ICDAR 2023)](https://link.springer.com/chapter/10.1007/978-3-031-41676-7_9) + +The code will be released after the ICDAR2023 meeting! + diff --git a/config.yaml b/config_v2.yaml similarity index 69% rename from config.yaml rename to config_v2.yaml index c067eea..789da00 100644 --- a/config.yaml +++ b/config_v2.yaml @@ -1,5 +1,17 @@ # 实验名称 -experiment: "v1_l2-loss" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_ori_train_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1_ddp" # "ori_CAN_in_full_crohme" +experiment: "can-l2-context-word" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_ori_train_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1_ddp" # "ori_CAN_in_full_crohme" + +sim_loss: + type: l2 + use_flag: True +context_loss: False +word_state_loss: False + + +counting_decoder: + use_flag: True + in_channel: 684 + out_channel: 111 # 随机种子 seed: 20211024 @@ -7,10 +19,10 @@ seed: 20211024 # 训练参数 epochs: 200 batch_size: 8 # 8 -workers: 0 # 0 +workers: 5 # 0 train_parts: 1 valid_parts: 1 -valid_start: 20 # 1000000000 +valid_start: 100 # 1000000000 save_start: 0 # 220 optimizer: Adadelta @@ -22,9 +34,9 @@ eps: 1e-6 weight_decay: 1e-4 beta: 0.9 -output_counting_feature: False -output_channel_attn_feature: False -counting_loss_ratio: 1 +# output_counting_feature: False +# output_channel_attn_feature: False +# counting_loss_ratio: 1 dropout: True dropout_ratio: 0.5 @@ -69,7 +81,7 @@ encoder: out_channel: 684 decoder: - net: Decoder_v1 + net: Decoder_v3 cell: 'GRU' input_size: 256 hidden_size: 256 @@ -79,15 +91,13 @@ attention: attention_dim: 512 word_conv_kernel: 1 -sim_loss: - type: l2 whiten_type: None max_step: 256 optimizer_save: False -finetune: False +finetune: True checkpoint_dir: 'checkpoints' checkpoint: "" -log_dir: 'logs' +log_dir: 'logs_v2' diff --git a/counting_utils.py b/counting_utils.py new file mode 100644 index 0000000..6219123 --- /dev/null +++ b/counting_utils.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +import os + +def gen_counting_label(labels, channel, tag): + b, t = labels.size() + counting_labels = torch.zeros((b, channel)) + if tag: + ignore = [0, 1, 107, 108, 109, 110] + else: + ignore = [] + for i in range(b): + for j in range(t): + k = labels[i][j] + if k in ignore: + continue + else: + counting_labels[i][k] += 1 + return counting_labels.detach() diff --git a/dataset.py b/dataset.py index 0c2d50e..fdca8d5 100644 --- a/dataset.py +++ b/dataset.py @@ -2,6 +2,7 @@ import time import pickle as pkl from torch.utils.data import DataLoader, Dataset, RandomSampler, DistributedSampler +from counting_utils import gen_counting_label class HMERDataset(Dataset): @@ -33,6 +34,10 @@ def __init__(self, params, image_path, label_path, words, is_train=True, use_aug self.reverse_color = self.params['data_process']['reverse_color'] if 'data_process' in params else False self.equal_range = self.params['data_process']['equal_range'] if 'data_process' in params else False + with open(self.params['matrix_path'], 'rb') as f: + matrix = pkl.load(f) + self.matrix = torch.Tensor(matrix) + def __len__(self): # assert len(self.images) == len(self.labels) return len(self.labels) @@ -55,7 +60,28 @@ def __getitem__(self, idx): words = self.words.encode(labels) + [0] words = torch.LongTensor(words) return image, words - + + def gen_matrix(self, labels): + (B, L), device = labels.shape, labels.device + matrix = [] + for i in range(B): + _L = [] + label = labels[i] + for x in range(L): + _T = [] + for y in range(L): + if x == y: + _T.append(1.) + else: + if label[x] == label[y] or label[x] == 0 or label[y] == 0: + _T.append(0.) + else: + _T.append(self.matrix[label[x], label[y]]) + _L.append(_T) + matrix.append(_L) + matrix = torch.tensor(matrix).to(device) + return matrix.detach() + def collate_fn(self, batch_images): max_width, max_height, max_length = 0, 0, 0 batch, channel = len(batch_images), batch_images[0][0].shape[0] @@ -80,7 +106,9 @@ def collate_fn(self, batch_images): l = proper_items[i][1].shape[0] labels[i][:l] = proper_items[i][1] labels_masks[i][:l] = 1 - return images, image_masks, labels, labels_masks + matrix = self.gen_matrix(labels) + counting_labels = gen_counting_label(labels, self.params['counting_decoder']['out_channel'], True) + return images, image_masks, labels, labels_masks, matrix, counting_labels def get_crohme_dataset(params): diff --git a/inference.py b/inference.py deleted file mode 100644 index 2eda079..0000000 --- a/inference.py +++ /dev/null @@ -1,122 +0,0 @@ -import os -import cv2 -import argparse -import torch -import torch.nn as nn -import json -import pickle as pkl -from tqdm import tqdm -import time - -from utils import load_config, load_checkpoint, compute_edit_distance -from models.infer_model import Inference -from dataset import Words - -parser = argparse.ArgumentParser(description='model testing') -parser.add_argument('--dataset', default='CROHME', type=str, help='数据集名称') -parser.add_argument('--config', default='CROHME', type=str, help='数据集名称') -parser.add_argument('--image_path', default='/liuzhuang7/CROHME/19_test_images.pkl', type=str, help='测试image路径') -parser.add_argument('--label_path', default='/liuzhuang7/CROHME/19_test_labels.txt', type=str, help='测试label路径') -parser.add_argument('--word_path', default='/liuzhuang7/CROHME/words_dict.txt', type=str, help='测试dict路径') -parser.add_argument('--model_path', default='', type=str, help='path of trained model') - -parser.add_argument('--draw_map', action='store_true') -args = parser.parse_args() - -if not args.dataset: - print('请提供数据集名称') - exit(-1) - -"""加载config文件""" -params = load_config(args.config) - -# os.environ['CUDA_VISIBLE_DEVICES'] = '0' -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -params['device'] = device -words = Words(args.word_path) -params['word_num'] = len(words) -params['words'] = words -if args.model_path != '': - params['checkpoint'] = args.model_path -print(params['checkpoint']) - -if 'use_label_mask' not in params: - params['use_label_mask'] = False -print(params['decoder']['net']) -model = Inference(params, draw_map=args.draw_map) -model = model.to(device) - -load_checkpoint(model, None, params['checkpoint']) -model.eval() - -with open(args.image_path, 'rb') as f: - images = pkl.load(f) - -with open(args.label_path) as f: - lines = f.readlines() - -line_right = 0 -e1, e2, e3 = 0, 0, 0 -bad_case = {} -model_time = 0 -mae_sum, mse_sum = 0, 0 -num = 1 -with tqdm(lines) as pbar, torch.no_grad(): - for line in pbar: - name, *labels = line.split() - # if name!='20_em_33': - # continue - name = name.split('.')[0] if name.endswith('jpg') else name - input_labels = labels - labels = ' '.join(labels) - img = images[name] - if params['data_process']['reverse_color']: - img = 255 - img - if params['data_process']['equal_range']: - img = (img / 255 - 0.5) * 2 - else: - img = img / 255 - img = torch.Tensor(img) - img = img.unsqueeze(0).unsqueeze(0) - img = img.to(device) - a = time.time() - - input_labels = words.encode(input_labels) - input_labels = torch.LongTensor(input_labels) - input_labels = input_labels.unsqueeze(0).to(device) - - probs, _, mae, mse = model(img, input_labels, os.path.join(params['decoder']['net'], name)) - mae_sum += mae - mse_sum += mse - model_time += (time.time() - a) - - prediction = words.decode(probs) - if prediction == labels: - line_right += 1 - else: - bad_case[name] = { - 'label': labels, - 'predi': prediction - } - # print(name, prediction, labels) - - distance = compute_edit_distance(prediction, labels) - if distance <= 1: - e1 += 1 - if distance <= 2: - e2 += 1 - if distance <= 3: - e3 += 1 - pbar.set_description(f'ExpRate: {line_right/num: .4f}') - num += 1 - -print(f'model time: {model_time}') -print(f'ExpRate: {line_right / len(lines)}') -print(f'mae: {mae_sum / len(lines)}') -print(f'mse: {mse_sum / len(lines)}') -print(f'e1: {e1 / len(lines)}') -print(f'e2: {e2 / len(lines)}') -print(f'e3: {e3 / len(lines)}') - -with open(f'{params["decoder"]["net"]}_bad_case.json','w') as f: - json.dump(bad_case,f,ensure_ascii=False) diff --git a/models/__init__.py b/models/__init__.py index 6dfb1c3..af29e9e 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1 +1,4 @@ -from models.decoder.decoder_v1 import Decoder_v1 \ No newline at end of file +from models.decoder.decoder_v1 import Decoder_v1 +from models.decoder.decoder_v2 import Decoder_v1 as Decoder_v2 +from models.decoder.decoder_v3 import Decoder_v1 as Decoder_v3 +from models.counting import CountingDecoder as counting_decoder \ No newline at end of file diff --git a/models/backbone.py b/models/backbone.py index 9c4fc81..61a1b32 100644 --- a/models/backbone.py +++ b/models/backbone.py @@ -6,9 +6,11 @@ import models from models.densenet import DenseNet +from einops.layers.torch import Rearrange +from traceback import print_exc class Model(nn.Module): - def __init__(self, params=None): + def __init__(self, params={}): super(Model, self).__init__() self.params = params @@ -21,11 +23,6 @@ def __init__(self, params=None): self.use_label_mask = params['use_label_mask'] self.encoder = DenseNet(params=self.params) - # self.in_channel = params['counting_decoder']['in_channel'] - # self.out_channel = params['counting_decoder']['out_channel'] - - # self.output_counting_feature = params['output_counting_feature'] if 'output_counting_feature' in params else False - # self.channel_attn_feature = params['output_channel_attn_feature'] if 'output_channel_attn_feature' in params else False self.decoder = getattr(models, params['decoder']['net'])(params=self.params) self.cross = nn.CrossEntropyLoss(reduction='none') if self.use_label_mask else nn.CrossEntropyLoss() @@ -35,19 +32,62 @@ def __init__(self, params=None): """经过cnn后 长宽与原始尺寸比缩小的比例""" self.ratio = params['densenet']['ratio'] - def forward(self, images, images_mask, labels, labels_mask, is_train=True): + if self.params['context_loss'] or self.params['word_state_loss']: + self.cma_context = nn.Sequential( + nn.Linear(params['encoder']['out_channel'], params['decoder']['input_size']), + Rearrange("b l h->b h l"), + nn.BatchNorm1d(params['decoder']['input_size']), + Rearrange("b h l->b l h"), + nn.ReLU() + ) + self.cma_word = nn.Sequential( + nn.Linear(params['decoder']['input_size'], params['decoder']['input_size']), + Rearrange("b l h->b h l"), + nn.BatchNorm1d(params['decoder']['input_size']), + Rearrange("b h l->b l h"), + nn.ReLU() + ) + + def forward(self, images, images_mask, labels, labels_mask, matrix=None, counting_labels=None, is_train=True): cnn_features = self.encoder(images) - word_probs, word_alphas, embedding = self.decoder(cnn_features, labels, images_mask, labels_mask, is_train=is_train) - + word_probs, word_alphas, embedding = self.decoder(cnn_features, labels, images_mask, labels_mask, counting_labels=counting_labels, is_train=is_train) + + context_loss, word_state_loss, word_sim_loss, counting_loss = 0, 0, 0, 0 + embedding, word_context_vec_list, word_out_state_list, _, counting_loss = embedding + if self.params['context_loss'] or self.params['word_state_loss'] and is_train: + if 'context_loss' in self.params and self.params['context_loss']: + word_context_vec_list = torch.stack(word_context_vec_list, 1) + context_embedding = self.cma_context(word_context_vec_list) + context_loss = self.cal_cam_loss_v2(context_embedding, labels, matrix) + if 'word_state_loss' in self.params and self.params['word_state_loss']: + word_out_state_list = torch.stack(word_out_state_list, 1) + word_state_embedding = self.cma_word(word_out_state_list) + word_state_loss = self.cal_cam_loss_v2(word_state_embedding, labels, matrix) + word_loss = self.cross(word_probs.contiguous().view(-1, word_probs.shape[-1]), labels.view(-1)) word_average_loss = (word_loss * labels_mask.view(-1)).sum() / (labels_mask.sum() + 1e-10) if self.use_label_mask else word_loss - word_sim_loss = self.cal_word_similarity(embedding) + if 'sim_loss' in self.params and self.params['sim_loss']['use_flag']: + word_sim_loss = self.cal_word_similarity(embedding) - return word_probs, (word_average_loss, word_sim_loss) + return word_probs, (word_average_loss, word_sim_loss, context_loss, word_state_loss, counting_loss) + def cal_cam_loss_v2(self, word_embedding, labels, matrix): + (B, L, H), device = word_embedding.shape, word_embedding.device + + W = torch.matmul(word_embedding, word_embedding.transpose(-1, -2)) # B L L + denom = torch.matmul(word_embedding.unsqueeze(-2), word_embedding.unsqueeze(-1)).squeeze(-1) ** (0.5) + # B L 1 H @ B L H 1 -> B L 1 1 + cosine = W / (denom @ denom.transpose(-1, -2)) + sim_mask = matrix != 0 + if self.sim_loss_type == 'l1': + loss = abs((cosine - matrix) * sim_mask) + else: + loss = (cosine - matrix) ** 2 * sim_mask + return loss.sum() / B / (labels != 0).sum() + def cal_word_similarity(self, word_embedding): num = word_embedding @ word_embedding.transpose(1,0) diff --git a/models/counting.py b/models/counting.py new file mode 100644 index 0000000..399af43 --- /dev/null +++ b/models/counting.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ChannelAtt(nn.Module): + def __init__(self, channel, reduction): + super(ChannelAtt, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel//reduction), + nn.ReLU(), + nn.Linear(channel//reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class CountingDecoder(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size): + super(CountingDecoder, self).__init__() + self.in_channel = in_channel + self.out_channel = out_channel + self.trans_layer = nn.Sequential( + nn.Conv2d(self.in_channel, 512, kernel_size=kernel_size, padding=kernel_size//2, bias=False), + nn.BatchNorm2d(512)) + self.channel_att = ChannelAtt(512, 16) + self.pred_layer = nn.Sequential( + nn.Conv2d(512, self.out_channel, kernel_size=1, bias=False), + nn.Sigmoid()) + + def forward(self, x, mask): + b, c, h, w = x.size() + x = self.trans_layer(x) + x = self.channel_att(x) + x = self.pred_layer(x) + if mask is not None: + x = x * mask + x = x.view(b, self.out_channel, -1) + x1 = torch.sum(x, dim=-1) + return x1, x.view(b, self.out_channel, h, w) diff --git a/models/decoder/decoder_v2.py b/models/decoder/decoder_v2.py new file mode 100644 index 0000000..03dcab8 --- /dev/null +++ b/models/decoder/decoder_v2.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +from models.decoder.attention import Attention + + +class Decoder_v1(nn.Module): + def __init__(self, params): + super(Decoder_v1, self).__init__() + self.params = params + self.input_size = params['decoder']['input_size'] + self.hidden_size = params['decoder']['hidden_size'] + self.out_channel = params['encoder']['out_channel'] + self.attention_dim = params['attention']['attention_dim'] + self.dropout_prob = params['dropout'] + self.device = params['device'] + self.word_num = params['word_num'] + # self.counting_num = params['counting_decoder']['out_channel'] + + """经过cnn后 长宽与原始尺寸比缩小的比例""" + self.ratio = params['densenet']['ratio'] + + # init hidden state + self.init_weight = nn.Linear(self.out_channel, self.hidden_size) + # word embedding + self.embedding = nn.Embedding(self.word_num, self.input_size) + # word gru + self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size) + self.word_output_gru = nn.GRUCell(self.out_channel, self.hidden_size) + # attention + self.word_attention = Attention(params) + self.encoder_feature_conv = nn.Conv2d(self.out_channel, self.attention_dim, + kernel_size=params['attention']['word_conv_kernel'], + padding=params['attention']['word_conv_kernel']//2) + + self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size) + self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size) + self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size) + # self.counting_context_weight = nn.Linear(self.counting_num, self.hidden_size) + self.word_convert = nn.Linear(self.hidden_size, self.word_num) + + if params['dropout']: + self.dropout = nn.Dropout(params['dropout_ratio']) + + def forward(self, cnn_features, labels, images_mask, label_mask, is_train=True): + batch_size, num_steps = labels.shape + height, width = cnn_features.shape[2:] + word_probs = torch.zeros((batch_size, num_steps, self.word_num)).to(device=self.device) + images_mask = images_mask[:, :, ::self.ratio, ::self.ratio] + + word_alpha_sum = torch.zeros((batch_size, 1, height, width)).to(device=self.device) + word_alphas = torch.zeros((batch_size, num_steps, height, width)).to(device=self.device) + hidden = self.init_hidden(cnn_features, images_mask) + + cnn_features_trans = self.encoder_feature_conv(cnn_features) + + word_context_vec_list, label_list, word_out_state_list = [], [], [] + if is_train: + for i in range(num_steps): + word_embedding = self.embedding(labels[:, i-1]) if i else self.embedding(torch.ones([batch_size]).long().to(self.device)) + hidden = self.word_input_gru(word_embedding, hidden) + word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden, + word_alpha_sum, images_mask) + hidden = self.word_output_gru(word_context_vec, hidden) + + current_state = self.word_state_weight(hidden) + word_weighted_embedding = self.word_embedding_weight(word_embedding) + word_context_weighted = self.word_context_weight(word_context_vec) + + if self.params['dropout']: + word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted) + else: + word_out_state = current_state + word_weighted_embedding + word_context_weighted + + word_prob = self.word_convert(word_out_state) + word_probs[:, i] = word_prob + word_alphas[:, i] = word_alpha + + word_context_vec_list.append(word_context_vec) + label_list.append(labels[:, i]) + word_out_state_list.append(word_out_state) + + else: + word_embedding = self.embedding(torch.ones([batch_size]).long().to(device=self.device)) + for i in range(num_steps): + hidden = self.word_input_gru(word_embedding, hidden) + word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden, + word_alpha_sum, images_mask) + hidden = self.word_output_gru(word_context_vec, hidden) + + current_state = self.word_state_weight(hidden) + word_weighted_embedding = self.word_embedding_weight(word_embedding) + word_context_weighted = self.word_context_weight(word_context_vec) + + if self.params['dropout']: + word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted) + else: + word_out_state = current_state + word_weighted_embedding + word_context_weighted + + word_prob = self.word_convert(word_out_state) + _, word = word_prob.max(1) + word_embedding = self.embedding(word) + word_probs[:, i] = word_prob + word_alphas[:, i] = word_alpha + return word_probs, word_alphas, (self.embedding.weight, word_context_vec_list, word_out_state_list, label_list) + + def init_hidden(self, features, feature_mask): + average = (features * feature_mask).sum(-1).sum(-1) / feature_mask.sum(-1).sum(-1) + average = self.init_weight(average) + return torch.tanh(average) + + diff --git a/models/decoder/decoder_v3.py b/models/decoder/decoder_v3.py new file mode 100644 index 0000000..e6cee0f --- /dev/null +++ b/models/decoder/decoder_v3.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn +from models.decoder.attention import Attention +import models + +class Decoder_v1(nn.Module): + def __init__(self, params): + super(Decoder_v1, self).__init__() + self.params = params + self.input_size = params['decoder']['input_size'] + self.hidden_size = params['decoder']['hidden_size'] + self.out_channel = params['encoder']['out_channel'] + self.attention_dim = params['attention']['attention_dim'] + self.dropout_prob = params['dropout'] + self.device = params['device'] + self.word_num = params['word_num'] + # self.counting_num = params['counting_decoder']['out_channel'] + + """经过cnn后 长宽与原始尺寸比缩小的比例""" + self.ratio = params['densenet']['ratio'] + + # init hidden state + self.init_weight = nn.Linear(self.out_channel, self.hidden_size) + # word embedding + self.embedding = nn.Embedding(self.word_num, self.input_size) + # word gru + self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size) + self.word_output_gru = nn.GRUCell(self.out_channel, self.hidden_size) + # attention + self.word_attention = Attention(params) + self.encoder_feature_conv = nn.Conv2d(self.out_channel, self.attention_dim, + kernel_size=params['attention']['word_conv_kernel'], + padding=params['attention']['word_conv_kernel']//2) + + self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size) + self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size) + self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size) + # self.counting_context_weight = nn.Linear(self.counting_num, self.hidden_size) + self.word_convert = nn.Linear(self.hidden_size, self.word_num) + + if params['dropout']: + self.dropout = nn.Dropout(params['dropout_ratio']) + + if "counting_decoder" in self.params and self.params['counting_decoder']['use_flag']: + self.counting_context_weight = nn.Linear(self.params['counting_decoder']['out_channel'], self.params['decoder']['hidden_size']) + self.counting_decoder1 = getattr(models, "counting_decoder")(self.params['counting_decoder']['in_channel'], self.params['counting_decoder']['out_channel'], 3) + self.counting_decoder2 = getattr(models, "counting_decoder")(self.params['counting_decoder']['in_channel'], self.params['counting_decoder']['out_channel'], 5) + + self.counting_loss = nn.SmoothL1Loss(reduction='mean') + + def counting(self, cnn_features, images_mask, counting_labels): + counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio] + + counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask) + counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask) + counting_preds = (counting_preds1 + counting_preds2) / 2 + counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \ + + self.counting_loss(counting_preds, counting_labels) + + return counting_loss, counting_preds + + def forward(self, cnn_features, labels, images_mask, label_mask, counting_labels=None, is_train=True): + counting_loss, counting_context_weighted = 0., 0. + if "counting_decoder" in self.params and self.params['counting_decoder']['use_flag']: + counting_loss, counting_preds = self.counting(cnn_features, images_mask, counting_labels) + counting_context_weighted = self.counting_context_weight(counting_preds) + + batch_size, num_steps = labels.shape + height, width = cnn_features.shape[2:] + word_probs = torch.zeros((batch_size, num_steps, self.word_num)).to(device=self.device) + images_mask = images_mask[:, :, ::self.ratio, ::self.ratio] + + word_alpha_sum = torch.zeros((batch_size, 1, height, width)).to(device=self.device) + word_alphas = torch.zeros((batch_size, num_steps, height, width)).to(device=self.device) + hidden = self.init_hidden(cnn_features, images_mask) + + cnn_features_trans = self.encoder_feature_conv(cnn_features) + + word_context_vec_list, label_list, word_out_state_list = [], [], [] + if is_train: + for i in range(num_steps): + word_embedding = self.embedding(labels[:, i-1]) if i else self.embedding(torch.ones([batch_size]).long().to(self.device)) + hidden = self.word_input_gru(word_embedding, hidden) + word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden, + word_alpha_sum, images_mask) + hidden = self.word_output_gru(word_context_vec, hidden) + + current_state = self.word_state_weight(hidden) + word_weighted_embedding = self.word_embedding_weight(word_embedding) + word_context_weighted = self.word_context_weight(word_context_vec) + + if self.params['dropout']: + word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted) + else: + word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted + + word_prob = self.word_convert(word_out_state) + word_probs[:, i] = word_prob + word_alphas[:, i] = word_alpha + + word_context_vec_list.append(word_context_vec) + label_list.append(labels[:, i]) + word_out_state_list.append(word_out_state) + + else: + word_embedding = self.embedding(torch.ones([batch_size]).long().to(device=self.device)) + for i in range(num_steps): + hidden = self.word_input_gru(word_embedding, hidden) + word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden, + word_alpha_sum, images_mask) + hidden = self.word_output_gru(word_context_vec, hidden) + + current_state = self.word_state_weight(hidden) + word_weighted_embedding = self.word_embedding_weight(word_embedding) + word_context_weighted = self.word_context_weight(word_context_vec) + + if self.params['dropout']: + word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted) + else: + word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted + + word_prob = self.word_convert(word_out_state) + _, word = word_prob.max(1) + word_embedding = self.embedding(word) + word_probs[:, i] = word_prob + word_alphas[:, i] = word_alpha + return word_probs, word_alphas, (self.embedding.weight, word_context_vec_list, word_out_state_list, label_list, counting_loss) + + def init_hidden(self, features, feature_mask): + average = (features * feature_mask).sum(-1).sum(-1) / feature_mask.sum(-1).sum(-1) + average = self.init_weight(average) + return torch.tanh(average) + + diff --git a/models/infer_model.py b/models/infer_model.py deleted file mode 100644 index e285080..0000000 --- a/models/infer_model.py +++ /dev/null @@ -1,171 +0,0 @@ -import os -import cv2 -import torch -import torch.nn as nn -import math -import numpy as np - - -from models.densenet import DenseNet -from models.decoder.attention import Attention -from utils import draw_attention_map, draw_counting_map - - -class Inference(nn.Module): - def __init__(self, params=None, draw_map=False): - super(Inference, self).__init__() - self.params = params - self.draw_map = draw_map - self.use_label_mask = params['use_label_mask'] - self.encoder = DenseNet(params=self.params) - self.in_channel = params['counting_decoder']['in_channel'] - self.out_channel = params['counting_decoder']['out_channel'] - self.device = params['device'] - self.decoder = decoder_dict[params['decoder']['net']](params=self.params) - - """经过cnn后 长宽与原始尺寸比缩小的比例""" - self.ratio = params['densenet']['ratio'] - - with open(params['word_path']) as f: - words = f.readlines() - print(f'共 {len(words)} 类符号。') - self.words_index_dict = {i: words[i].strip() for i in range(len(words))} - self.cal_mae = nn.L1Loss(reduction='mean') - self.cal_mse = nn.MSELoss(reduction='mean') - self.ignore_symbos = self.params['words'].encode(['','','{','}','_','^']) - - def forward(self, images, labels, name, is_train=False): - cnn_features = self.encoder(images) - batch_size, _, height, width = cnn_features.shape - counting_preds1, counting_maps1 = self.counting_decoder1(cnn_features, None) - counting_preds2, counting_maps2 = self.counting_decoder2(cnn_features, None) - counting_preds = (counting_preds1 + counting_preds2) / 2 - counting_maps = (counting_maps1 + counting_maps2) / 2 - - mae = self.cal_mae(counting_preds, gen_counting_label(labels, self.out_channel, True, ignore_symbos=self.ignore_symbos)).item() - mse = math.sqrt(self.cal_mse(counting_preds, gen_counting_label(labels, self.out_channel, True, ignore_symbos=self.ignore_symbos)).item()) - - word_probs, alphas = self.decoder(cnn_features, counting_preds, counting_map=counting_maps, is_train=is_train) - - if self.params['decoder']['net'] in ['AttFusionLocDecoder_v2','AttFusionLocDecoder_v3']: - word_alphas, loc_alphas = alphas - else: - word_alphas = alphas - - if self.draw_map: - if not os.path.exists(os.path.join(self.params['attention_map_vis_path'], name)): - os.makedirs(os.path.join(self.params['attention_map_vis_path'], name), exist_ok=True) - if not os.path.exists(os.path.join(self.params['counting_map_vis_path'], name)): - os.makedirs(os.path.join(self.params['counting_map_vis_path'], name), exist_ok=True) - for i in range(images.shape[0]): - img = 255 - images[i][0].detach().cpu().numpy() * 255 - # draw attention_map and loc attention_map - for step in range(len(word_probs)): - word_atten = word_alphas[step][0].detach().cpu().numpy() - word_heatmap = draw_attention_map(img, word_atten) - - if self.params['decoder']['net'] in ['AttFusionLocDecoder_v2', 'AttFusionLocDecoder_v3']: - loc_atten = loc_alphas[step][0].detach().cpu().numpy() - loc_heatmap = draw_attention_map(img, loc_atten) - - h, w, c = word_heatmap.shape - edge = np.ones([10,w,c]) * 255 - - out = np.concatenate([word_heatmap, edge, loc_heatmap], axis=0) - else: - out = word_heatmap - - cv2.imwrite(os.path.join(self.params['attention_map_vis_path'], name, f'{step}.jpg'), out) - - # draw counting_map - for idx in range(self.out_channel): - counting_map = counting_maps[0].permute(1,2,0)[:,:,idx].detach().cpu() - counting_heatmap = draw_counting_map(img, counting_map) - img_name = 'symbol_' + self.words_index_dict[idx] + '_map.jpg' - cv2.imwrite(os.path.join(self.params['counting_map_vis_path'], name, img_name), counting_heatmap) - - return word_probs, word_alphas, mae, mse - - -class AttDecoder(nn.Module): - def __init__(self, params): - super(AttDecoder, self).__init__() - self.params = params - self.input_size = params['decoder']['input_size'] - self.hidden_size = params['decoder']['hidden_size'] - self.out_channel = params['encoder']['out_channel'] - self.attention_dim = params['attention']['attention_dim'] - self.dropout_prob = params['dropout'] - self.device = params['device'] - self.word_num = params['word_num'] - self.ratio = params['densenet']['ratio'] - - self.init_weight = nn.Linear(self.out_channel, self.hidden_size) - self.embedding = nn.Embedding(self.word_num, self.input_size) - self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size) - self.encoder_feature_conv = nn.Conv2d(self.out_channel, self.attention_dim, kernel_size=1) - self.word_attention = Attention(params) - - self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size) - self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size) - self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size) - self.counting_context_weight = nn.Linear(self.word_num, self.hidden_size) - self.word_convert = nn.Linear(self.hidden_size, self.word_num) - - if params['dropout']: - self.dropout = nn.Dropout(params['dropout_ratio']) - - def forward(self, cnn_features, counting_preds, is_train=False): - batch_size, _, height, width = cnn_features.shape - image_mask = torch.ones((batch_size, 1, height, width)).to(self.device) - - cnn_features_trans = self.encoder_feature_conv(cnn_features) - position_embedding = PositionEmbeddingSine(256, normalize=True) - pos = position_embedding(cnn_features_trans, image_mask[:,0,:,:]) - cnn_features_trans = cnn_features_trans + pos - - word_alpha_sum = torch.zeros((batch_size, 1, height, width)).to(device=self.device) - hidden = self.init_hidden(cnn_features, image_mask) - word_embedding = self.embedding(torch.ones([batch_size]).long().to(device=self.device)) - counting_context_weighted = self.counting_context_weight(counting_preds) - word_probs = [] - word_alphas = [] - - i = 0 - while i < 200: - hidden = self.word_input_gru(word_embedding, hidden) - word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden, - word_alpha_sum, image_mask) - - current_state = self.word_state_weight(hidden) - word_weighted_embedding = self.word_embedding_weight(word_embedding) - word_context_weighted = self.word_context_weight(word_context_vec) - - if self.params['dropout']: - word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted) - else: - word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted - - word_prob = self.word_convert(word_out_state) - _, word = word_prob.max(1) - word_embedding = self.embedding(word) - if word.item() == 0: - return word_probs, word_alphas - word_alphas.append(word_alpha) - word_probs.append(word) - i+=1 - return word_probs, word_alphas - - def init_hidden(self, features, feature_mask): - average = (features * feature_mask).sum(-1).sum(-1) / feature_mask.sum(-1).sum(-1) - average = self.init_weight(average) - return torch.tanh(average) - - -decoder_dict = { - 'AttDecoder': AttDecoder, - # 'AttFusionLocDecoder_v1': AttFusionLocDecoder_v1, - # 'AttFusionLocDecoder_v2': AttFusionLocDecoder_v2, - # 'AttFusionLocDecoder_v3': AttFusionLocDecoder_v3, - # 'AttFusionLocDecoder_v4': AttFusionLocDecoder_v4 -} \ No newline at end of file diff --git a/train.py b/train.py index c88f8aa..d1ebac0 100644 --- a/train.py +++ b/train.py @@ -55,7 +55,8 @@ if params['finetune']: print('加载预训练模型权重') print(f'预训练权重路径: {params["checkpoint"]}') - load_checkpoint(model, optimizer, params['checkpoint']) + if params["checkpoint"]: + load_checkpoint(model, optimizer, params['checkpoint']) if not args.check: if not os.path.exists(os.path.join(params['checkpoint_dir'], model.name)): @@ -67,12 +68,12 @@ min_score = 0 min_step = 0 rate_2014, rate_2016, rate_2019 = 0.55, 0.54, 0.55 -rate_2014, rate_2016, rate_2019 = 0.50, 0.50, 0.50 +rate_2014, rate_2016, rate_2019 = 0.54, 0.54, 0.54 # init_epoch = 0 if not params['finetune'] else int(params['checkpoint'].split('_')[-1].split('.')[0]) if args.val: epoch = 1 - model.load_state_dict(torch.load("checkpoints/v1_l1-loss_2022-11-28-23-38_decoder-Decoder_v1/2016_v1_l1-loss_2022-11-28-23-38_decoder-Decoder_v1_WordRate-0.9057_ExpRate-0.5405_182.pth", map_location="cpu")['model']) + # model.load_state_dict(torch.load("checkpoints/v1_l2-loss_2022-11-30-09-57_decoder-Decoder_v1/2016_v1_l2-loss_2022-11-30-09-57_decoder-Decoder_v1_WordRate-0.9094_ExpRate-0.5562_190.pth", map_location="cpu")['model']) print() eval_loss, eval_word_score, eval_expRate = eval(params, model, epoch, eval_loader_14) print(f'2014 Epoch: {epoch + 1} loss: {eval_loss:.4f} word score: {eval_word_score:.4f} ExpRate: {eval_expRate:.4f}') diff --git a/training.py b/training.py index e97c234..45a3d64 100644 --- a/training.py +++ b/training.py @@ -1,26 +1,45 @@ import torch from tqdm import tqdm from utils import update_lr, Meter, cal_score - - -def train(params, model, optimizer, epoch, train_loader, writer=None): +from torch import nn +from copy import deepcopy + +shadow_model = {} + +def finetune_part(model: nn.Module, name): + for k, v in model.named_parameters(): + if name in k: + v.requires_grad = True + else: + v.requires_grad = False + if k in shadow_model: + assert torch.equal(shadow_model[k], v), "find params change!" + else: + shadow_model[k] = deepcopy(v) + return + +def train(params, model, optimizer:torch.optim.Optimizer, epoch, train_loader, writer=None): model.train() device = params['device'] loss_meter = Meter() word_right, exp_right, length, cal_num = 0, 0, 0, 0 + if params['finetune']: + finetune_part(model, 'counting') + with tqdm(train_loader, total=len(train_loader)//params['train_parts']) as pbar: - for batch_idx, (images, image_masks, labels, label_masks) in enumerate(pbar): - images, image_masks, labels, label_masks = images.to(device), image_masks.to( - device), labels.to(device), label_masks.to(device) + for batch_idx, (images, image_masks, labels, label_masks, matrix, counting_labels) in enumerate(pbar): + images, image_masks, labels, label_masks, matrix, counting_labels = \ + images.to(device, non_blocking=True), image_masks.to(device, non_blocking=True), \ + labels.to(device, non_blocking=True), label_masks.to(device, non_blocking=True), \ + matrix.to(device, non_blocking=True), counting_labels.to(device, non_blocking=True) batch, time = labels.shape[:2] if not 'lr_decay' in params or params['lr_decay'] == 'cosine': update_lr(optimizer, epoch, batch_idx, len(train_loader), params['epochs'], params['lr']) optimizer.zero_grad() - probs, loss = model(images, image_masks, labels, label_masks) - word_loss, sim_loss = loss - loss = word_loss + sim_loss - + probs, loss = model(images, image_masks, labels, label_masks, matrix=matrix, counting_labels=counting_labels) + word_loss, sim_loss, context_loss, word_state_loss, counting_loss = loss + loss = word_loss + sim_loss + context_loss + word_state_loss + counting_loss loss.backward() if params['gradient_clip']: @@ -33,17 +52,24 @@ def train(params, model, optimizer, epoch, train_loader, writer=None): exp_right = exp_right + ExpRate * batch length = length + time cal_num = cal_num + batch - + + if isinstance(sim_loss, torch.Tensor): + sim_loss = sim_loss.item() + if writer: current_step = epoch * len(train_loader) // params['train_parts'] + batch_idx + 1 writer.add_scalar('train/loss', loss.item(), current_step) + writer.add_scalar('train/sim', sim_loss, current_step) + writer.add_scalar('train/counting', counting_loss, current_step) + writer.add_scalar('train/context', context_loss, current_step) + writer.add_scalar('train/word', word_state_loss, current_step) writer.add_scalar('train/WordRate', wordRate, current_step) writer.add_scalar('train/ExpRate', ExpRate, current_step) writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], current_step) - - string = f'{epoch + 1} word_loss:{word_loss.item():.3f} sim_loss: {sim_loss.item():.3f} ' + + string = f'{epoch + 1} word_loss:{word_loss.item():.3f} sim_loss: {sim_loss:.3f} ' string += f'WRate:{word_right / length:.3f} ERate:{exp_right / cal_num:.3f}' - + pbar.set_description(string) if batch_idx >= len(train_loader) // params['train_parts']: break @@ -62,13 +88,13 @@ def eval(params, model, epoch, eval_loader, writer=None): word_right, exp_right, length, cal_num = 0, 0, 0, 0 with tqdm(eval_loader, total=len(eval_loader)//params['valid_parts']) as pbar, torch.no_grad(): - for batch_idx, (images, image_masks, labels, label_masks) in enumerate(pbar): - images, image_masks, labels, label_masks = images.to(device), image_masks.to( - device), labels.to(device), label_masks.to(device) + for batch_idx, (images, image_masks, labels, label_masks, _, counting_labels) in enumerate(pbar): + images, image_masks, labels, label_masks, counting_labels = images.to(device), image_masks.to( + device), labels.to(device), label_masks.to(device), counting_labels.to(device) batch, time = labels.shape[:2] - probs, loss = model(images, image_masks, labels, label_masks, is_train=False) + probs, loss = model(images, image_masks, labels, label_masks, counting_labels=counting_labels, is_train=False) - word_loss, sim_loss = loss + word_loss, sim_loss, _, _, counting_loss = loss loss = word_loss + sim_loss loss_meter.add(loss.item()) @@ -78,6 +104,9 @@ def eval(params, model, epoch, eval_loader, writer=None): length = length + time cal_num = cal_num + batch + if isinstance(sim_loss, torch.Tensor): + sim_loss = sim_loss.item() + if writer: current_step = epoch * len(eval_loader)//params['valid_parts'] + batch_idx + 1 writer.add_scalar('eval/word_loss', word_loss.item(), current_step) @@ -85,7 +114,7 @@ def eval(params, model, epoch, eval_loader, writer=None): writer.add_scalar('eval/WordRate', wordRate, current_step) writer.add_scalar('eval/ExpRate', ExpRate, current_step) - pbar.set_description(f'{epoch+1} word_loss:{word_loss.item():.4f} sim_loss:{sim_loss.item():.4f}' + pbar.set_description(f'{epoch+1} word_loss:{word_loss.item():.4f} sim_loss:{sim_loss:.4f}' f' WRate:{word_right / length:.4f} ERate:{exp_right / cal_num:.4f}') if batch_idx >= len(eval_loader) // params['valid_parts']: break