diff --git a/.gitignore b/.gitignore
index 9790f49..5c836c5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,8 @@ logs
datasets
.DS_Store
-*.log
+.vscode
+logs_v2
+heatmap
+184_checkpoints
+.ipynb_checkpoints
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..f288702
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program>  Copyright (C) <year>  <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/philosophy/why-not-lgpl.html>.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5cdf258
--- /dev/null
+++ b/README.md
@@ -0,0 +1,5 @@
+# SAM
+[Semantic Graph Representation Learning for Handwritten Mathematical Expression Recognition (ICDAR 2023)](https://link.springer.com/chapter/10.1007/978-3-031-41676-7_9)
+
+The code will be released after the ICDAR 2023 conference!
+
diff --git a/config.yaml b/config_v2.yaml
similarity index 69%
rename from config.yaml
rename to config_v2.yaml
index c067eea..789da00 100644
--- a/config.yaml
+++ b/config_v2.yaml
@@ -1,5 +1,18 @@
# Experiment name
-experiment: "v1_l2-loss" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_ori_train_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1_ddp" # "ori_CAN_in_full_crohme"
+experiment: "can-l2-context-word" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_ori_train_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1" # "ori_CAN_with_counting_predicted_class_info_in_full_crohme_with_gpus1_bs8_lr1_ddp" # "ori_CAN_in_full_crohme"
+
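+# Auxiliary losses: embedding-similarity loss, plus optional context / word-state alignment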
+sim_loss:
+ type: l2
+ use_flag: True
+context_loss: False
+word_state_loss: False
+
+
+counting_decoder:
+ use_flag: True
+ in_channel: 684
+ out_channel: 111
# Random seed
seed: 20211024
@@ -7,10 +19,10 @@ seed: 20211024
# Training parameters
epochs: 200
batch_size: 8 # 8
-workers: 0 # 0
+workers: 5 # 0
train_parts: 1
valid_parts: 1
-valid_start: 20 # 1000000000
+valid_start: 100 # 1000000000
save_start: 0 # 220
optimizer: Adadelta
@@ -22,9 +34,9 @@ eps: 1e-6
weight_decay: 1e-4
beta: 0.9
-output_counting_feature: False
-output_channel_attn_feature: False
-counting_loss_ratio: 1
+# output_counting_feature: False
+# output_channel_attn_feature: False
+# counting_loss_ratio: 1
dropout: True
dropout_ratio: 0.5
@@ -69,7 +81,7 @@ encoder:
out_channel: 684
decoder:
- net: Decoder_v1
+ net: Decoder_v3
cell: 'GRU'
input_size: 256
hidden_size: 256
@@ -79,15 +91,13 @@ attention:
attention_dim: 512
word_conv_kernel: 1
-sim_loss:
- type: l2
whiten_type: None
max_step: 256
optimizer_save: False
-finetune: False
+finetune: True
checkpoint_dir: 'checkpoints'
checkpoint: ""
-log_dir: 'logs'
+log_dir: 'logs_v2'
diff --git a/counting_utils.py b/counting_utils.py
new file mode 100644
index 0000000..6219123
--- /dev/null
+++ b/counting_utils.py
@@ -0,0 +1,18 @@
+import torch
+
+def gen_counting_label(labels, channel, tag):
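+    """Count occurrences of each symbol class in every label sequence; returns a
+    detached (batch, channel) tensor. When `tag` is set, a fixed id list
+    (assumed to be padding/eos and structural tokens) is excluded from counting."""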
+ b, t = labels.size()
+ counting_labels = torch.zeros((b, channel))
+ if tag:
+ ignore = [0, 1, 107, 108, 109, 110]
+ else:
+ ignore = []
+ for i in range(b):
+ for j in range(t):
+ k = labels[i][j]
+            if k not in ignore:
+                counting_labels[i][k] += 1
+ return counting_labels.detach()
diff --git a/dataset.py b/dataset.py
index 0c2d50e..fdca8d5 100644
--- a/dataset.py
+++ b/dataset.py
@@ -2,6 +2,7 @@
import time
import pickle as pkl
from torch.utils.data import DataLoader, Dataset, RandomSampler, DistributedSampler
+from counting_utils import gen_counting_label
class HMERDataset(Dataset):
@@ -33,6 +34,10 @@ def __init__(self, params, image_path, label_path, words, is_train=True, use_aug
self.reverse_color = self.params['data_process']['reverse_color'] if 'data_process' in params else False
self.equal_range = self.params['data_process']['equal_range'] if 'data_process' in params else False
+ with open(self.params['matrix_path'], 'rb') as f:
+ matrix = pkl.load(f)
+ self.matrix = torch.Tensor(matrix)
+
def __len__(self):
# assert len(self.images) == len(self.labels)
return len(self.labels)
@@ -55,7 +60,32 @@ def __getitem__(self, idx):
words = self.words.encode(labels) + [0]
words = torch.LongTensor(words)
return image, words
-
+
+ def gen_matrix(self, labels):
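+        """Build the (B, L, L) target similarity matrix for a batch of label
+        sequences from the precomputed symbol-relation matrix: 1 on the diagonal,
+        0 for padding (id 0) or repeated symbols, the precomputed value otherwise."""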
+ (B, L), device = labels.shape, labels.device
+ matrix = []
+ for i in range(B):
+ _L = []
+ label = labels[i]
+ for x in range(L):
+ _T = []
+ for y in range(L):
+ if x == y:
+ _T.append(1.)
+ else:
+ if label[x] == label[y] or label[x] == 0 or label[y] == 0:
+ _T.append(0.)
+ else:
+ _T.append(self.matrix[label[x], label[y]])
+ _L.append(_T)
+ matrix.append(_L)
+ matrix = torch.tensor(matrix).to(device)
+ return matrix.detach()
+
def collate_fn(self, batch_images):
max_width, max_height, max_length = 0, 0, 0
batch, channel = len(batch_images), batch_images[0][0].shape[0]
@@ -80,7 +106,9 @@ def collate_fn(self, batch_images):
l = proper_items[i][1].shape[0]
labels[i][:l] = proper_items[i][1]
labels_masks[i][:l] = 1
- return images, image_masks, labels, labels_masks
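+        # also return similarity targets and symbol-count targets for the new losses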
+ matrix = self.gen_matrix(labels)
+ counting_labels = gen_counting_label(labels, self.params['counting_decoder']['out_channel'], True)
+ return images, image_masks, labels, labels_masks, matrix, counting_labels
def get_crohme_dataset(params):
diff --git a/inference.py b/inference.py
deleted file mode 100644
index 2eda079..0000000
--- a/inference.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import os
-import cv2
-import argparse
-import torch
-import torch.nn as nn
-import json
-import pickle as pkl
-from tqdm import tqdm
-import time
-
-from utils import load_config, load_checkpoint, compute_edit_distance
-from models.infer_model import Inference
-from dataset import Words
-
-parser = argparse.ArgumentParser(description='model testing')
-parser.add_argument('--dataset', default='CROHME', type=str, help='dataset name')
-parser.add_argument('--config', default='CROHME', type=str, help='config path')
-parser.add_argument('--image_path', default='/liuzhuang7/CROHME/19_test_images.pkl', type=str, help='path to test images')
-parser.add_argument('--label_path', default='/liuzhuang7/CROHME/19_test_labels.txt', type=str, help='path to test labels')
-parser.add_argument('--word_path', default='/liuzhuang7/CROHME/words_dict.txt', type=str, help='path to the words dict')
-parser.add_argument('--model_path', default='', type=str, help='path of trained model')
-
-parser.add_argument('--draw_map', action='store_true')
-args = parser.parse_args()
-
-if not args.dataset:
-    print('Please provide a dataset name')
- exit(-1)
-
-"""Load the config file"""
-params = load_config(args.config)
-
-# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-params['device'] = device
-words = Words(args.word_path)
-params['word_num'] = len(words)
-params['words'] = words
-if args.model_path != '':
- params['checkpoint'] = args.model_path
-print(params['checkpoint'])
-
-if 'use_label_mask' not in params:
- params['use_label_mask'] = False
-print(params['decoder']['net'])
-model = Inference(params, draw_map=args.draw_map)
-model = model.to(device)
-
-load_checkpoint(model, None, params['checkpoint'])
-model.eval()
-
-with open(args.image_path, 'rb') as f:
- images = pkl.load(f)
-
-with open(args.label_path) as f:
- lines = f.readlines()
-
-line_right = 0
-e1, e2, e3 = 0, 0, 0
-bad_case = {}
-model_time = 0
-mae_sum, mse_sum = 0, 0
-num = 1
-with tqdm(lines) as pbar, torch.no_grad():
- for line in pbar:
- name, *labels = line.split()
- # if name!='20_em_33':
- # continue
- name = name.split('.')[0] if name.endswith('jpg') else name
- input_labels = labels
- labels = ' '.join(labels)
- img = images[name]
- if params['data_process']['reverse_color']:
- img = 255 - img
- if params['data_process']['equal_range']:
- img = (img / 255 - 0.5) * 2
- else:
- img = img / 255
- img = torch.Tensor(img)
- img = img.unsqueeze(0).unsqueeze(0)
- img = img.to(device)
- a = time.time()
-
- input_labels = words.encode(input_labels)
- input_labels = torch.LongTensor(input_labels)
- input_labels = input_labels.unsqueeze(0).to(device)
-
- probs, _, mae, mse = model(img, input_labels, os.path.join(params['decoder']['net'], name))
- mae_sum += mae
- mse_sum += mse
- model_time += (time.time() - a)
-
- prediction = words.decode(probs)
- if prediction == labels:
- line_right += 1
- else:
- bad_case[name] = {
- 'label': labels,
- 'predi': prediction
- }
- # print(name, prediction, labels)
-
- distance = compute_edit_distance(prediction, labels)
- if distance <= 1:
- e1 += 1
- if distance <= 2:
- e2 += 1
- if distance <= 3:
- e3 += 1
- pbar.set_description(f'ExpRate: {line_right/num: .4f}')
- num += 1
-
-print(f'model time: {model_time}')
-print(f'ExpRate: {line_right / len(lines)}')
-print(f'mae: {mae_sum / len(lines)}')
-print(f'mse: {mse_sum / len(lines)}')
-print(f'e1: {e1 / len(lines)}')
-print(f'e2: {e2 / len(lines)}')
-print(f'e3: {e3 / len(lines)}')
-
-with open(f'{params["decoder"]["net"]}_bad_case.json','w') as f:
- json.dump(bad_case,f,ensure_ascii=False)
diff --git a/models/__init__.py b/models/__init__.py
index 6dfb1c3..af29e9e 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1 +1,4 @@
-from models.decoder.decoder_v1 import Decoder_v1
\ No newline at end of file
+from models.decoder.decoder_v1 import Decoder_v1
+from models.decoder.decoder_v2 import Decoder_v1 as Decoder_v2
+from models.decoder.decoder_v3 import Decoder_v1 as Decoder_v3
+from models.counting import CountingDecoder as counting_decoder
\ No newline at end of file
diff --git a/models/backbone.py b/models/backbone.py
index 9c4fc81..61a1b32 100644
--- a/models/backbone.py
+++ b/models/backbone.py
@@ -6,9 +6,11 @@
import models
from models.densenet import DenseNet
+from einops.layers.torch import Rearrange
+from traceback import print_exc
class Model(nn.Module):
- def __init__(self, params=None):
+ def __init__(self, params={}):
super(Model, self).__init__()
self.params = params
@@ -21,11 +23,6 @@ def __init__(self, params=None):
self.use_label_mask = params['use_label_mask']
self.encoder = DenseNet(params=self.params)
- # self.in_channel = params['counting_decoder']['in_channel']
- # self.out_channel = params['counting_decoder']['out_channel']
-
- # self.output_counting_feature = params['output_counting_feature'] if 'output_counting_feature' in params else False
- # self.channel_attn_feature = params['output_channel_attn_feature'] if 'output_channel_attn_feature' in params else False
self.decoder = getattr(models, params['decoder']['net'])(params=self.params)
self.cross = nn.CrossEntropyLoss(reduction='none') if self.use_label_mask else nn.CrossEntropyLoss()
@@ -35,19 +32,67 @@ def __init__(self, params=None):
        """Ratio by which the CNN downsamples height/width relative to the input"""
self.ratio = params['densenet']['ratio']
- def forward(self, images, images_mask, labels, labels_mask, is_train=True):
+ if self.params['context_loss'] or self.params['word_state_loss']:
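+            # projection heads that map attention context vectors / decoder output
+            # states into a common space for the similarity losses (cal_cam_loss_v2)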
+ self.cma_context = nn.Sequential(
+ nn.Linear(params['encoder']['out_channel'], params['decoder']['input_size']),
+ Rearrange("b l h->b h l"),
+ nn.BatchNorm1d(params['decoder']['input_size']),
+ Rearrange("b h l->b l h"),
+ nn.ReLU()
+ )
+ self.cma_word = nn.Sequential(
+ nn.Linear(params['decoder']['input_size'], params['decoder']['input_size']),
+ Rearrange("b l h->b h l"),
+ nn.BatchNorm1d(params['decoder']['input_size']),
+ Rearrange("b h l->b l h"),
+ nn.ReLU()
+ )
+
+ def forward(self, images, images_mask, labels, labels_mask, matrix=None, counting_labels=None, is_train=True):
cnn_features = self.encoder(images)
- word_probs, word_alphas, embedding = self.decoder(cnn_features, labels, images_mask, labels_mask, is_train=is_train)
-
+ word_probs, word_alphas, embedding = self.decoder(cnn_features, labels, images_mask, labels_mask, counting_labels=counting_labels, is_train=is_train)
+
+ context_loss, word_state_loss, word_sim_loss, counting_loss = 0, 0, 0, 0
+ embedding, word_context_vec_list, word_out_state_list, _, counting_loss = embedding
+        if (self.params['context_loss'] or self.params['word_state_loss']) and is_train:
+ if 'context_loss' in self.params and self.params['context_loss']:
+ word_context_vec_list = torch.stack(word_context_vec_list, 1)
+ context_embedding = self.cma_context(word_context_vec_list)
+ context_loss = self.cal_cam_loss_v2(context_embedding, labels, matrix)
+ if 'word_state_loss' in self.params and self.params['word_state_loss']:
+ word_out_state_list = torch.stack(word_out_state_list, 1)
+ word_state_embedding = self.cma_word(word_out_state_list)
+ word_state_loss = self.cal_cam_loss_v2(word_state_embedding, labels, matrix)
+
word_loss = self.cross(word_probs.contiguous().view(-1, word_probs.shape[-1]), labels.view(-1))
word_average_loss = (word_loss * labels_mask.view(-1)).sum() / (labels_mask.sum() + 1e-10) if self.use_label_mask else word_loss
- word_sim_loss = self.cal_word_similarity(embedding)
+ if 'sim_loss' in self.params and self.params['sim_loss']['use_flag']:
+ word_sim_loss = self.cal_word_similarity(embedding)
- return word_probs, (word_average_loss, word_sim_loss)
+ return word_probs, (word_average_loss, word_sim_loss, context_loss, word_state_loss, counting_loss)
+ def cal_cam_loss_v2(self, word_embedding, labels, matrix):
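+        """Masked L1/L2 loss between pairwise cosine similarities of the projected
+        embeddings and the target similarity `matrix`, normalized by batch size
+        and the number of non-padding tokens."""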
+ (B, L, H), device = word_embedding.shape, word_embedding.device
+
+ W = torch.matmul(word_embedding, word_embedding.transpose(-1, -2)) # B L L
+ denom = torch.matmul(word_embedding.unsqueeze(-2), word_embedding.unsqueeze(-1)).squeeze(-1) ** (0.5)
+ # B L 1 H @ B L H 1 -> B L 1 1
+ cosine = W / (denom @ denom.transpose(-1, -2))
+ sim_mask = matrix != 0
+ if self.sim_loss_type == 'l1':
+ loss = abs((cosine - matrix) * sim_mask)
+ else:
+ loss = (cosine - matrix) ** 2 * sim_mask
+ return loss.sum() / B / (labels != 0).sum()
+
def cal_word_similarity(self, word_embedding):
num = word_embedding @ word_embedding.transpose(1,0)
diff --git a/models/counting.py b/models/counting.py
new file mode 100644
index 0000000..399af43
--- /dev/null
+++ b/models/counting.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ChannelAtt(nn.Module):
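+    """Squeeze-and-excitation style channel attention: global average pooling
+    followed by a bottleneck MLP that produces per-channel gates."""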
+ def __init__(self, channel, reduction):
+ super(ChannelAtt, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel//reduction),
+ nn.ReLU(),
+ nn.Linear(channel//reduction, channel),
+ nn.Sigmoid())
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+
+class CountingDecoder(nn.Module):
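+    """Counting module (as in CAN): predicts a per-class symbol-count vector
+    from CNN features and also returns the spatial counting map."""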
+ def __init__(self, in_channel, out_channel, kernel_size):
+ super(CountingDecoder, self).__init__()
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.trans_layer = nn.Sequential(
+ nn.Conv2d(self.in_channel, 512, kernel_size=kernel_size, padding=kernel_size//2, bias=False),
+ nn.BatchNorm2d(512))
+ self.channel_att = ChannelAtt(512, 16)
+ self.pred_layer = nn.Sequential(
+ nn.Conv2d(512, self.out_channel, kernel_size=1, bias=False),
+ nn.Sigmoid())
+
+ def forward(self, x, mask):
+ b, c, h, w = x.size()
+ x = self.trans_layer(x)
+ x = self.channel_att(x)
+ x = self.pred_layer(x)
+ if mask is not None:
+ x = x * mask
+ x = x.view(b, self.out_channel, -1)
+ x1 = torch.sum(x, dim=-1)
+ return x1, x.view(b, self.out_channel, h, w)
diff --git a/models/decoder/decoder_v2.py b/models/decoder/decoder_v2.py
new file mode 100644
index 0000000..03dcab8
--- /dev/null
+++ b/models/decoder/decoder_v2.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+from models.decoder.attention import Attention
+
+
+class Decoder_v1(nn.Module):
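+    """GRU decoder with coverage attention; besides probs and alphas it returns
+    the embedding weight and per-step context / output-state lists used by the
+    alignment losses (exported as Decoder_v2 in models/__init__.py)."""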
+ def __init__(self, params):
+ super(Decoder_v1, self).__init__()
+ self.params = params
+ self.input_size = params['decoder']['input_size']
+ self.hidden_size = params['decoder']['hidden_size']
+ self.out_channel = params['encoder']['out_channel']
+ self.attention_dim = params['attention']['attention_dim']
+ self.dropout_prob = params['dropout']
+ self.device = params['device']
+ self.word_num = params['word_num']
+ # self.counting_num = params['counting_decoder']['out_channel']
+
+        """Ratio by which the CNN downsamples height/width relative to the input"""
+ self.ratio = params['densenet']['ratio']
+
+ # init hidden state
+ self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
+ # word embedding
+ self.embedding = nn.Embedding(self.word_num, self.input_size)
+ # word gru
+ self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
+ self.word_output_gru = nn.GRUCell(self.out_channel, self.hidden_size)
+ # attention
+ self.word_attention = Attention(params)
+ self.encoder_feature_conv = nn.Conv2d(self.out_channel, self.attention_dim,
+ kernel_size=params['attention']['word_conv_kernel'],
+ padding=params['attention']['word_conv_kernel']//2)
+
+ self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
+ self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size)
+ self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
+ # self.counting_context_weight = nn.Linear(self.counting_num, self.hidden_size)
+ self.word_convert = nn.Linear(self.hidden_size, self.word_num)
+
+ if params['dropout']:
+ self.dropout = nn.Dropout(params['dropout_ratio'])
+
+ def forward(self, cnn_features, labels, images_mask, label_mask, is_train=True):
+ batch_size, num_steps = labels.shape
+ height, width = cnn_features.shape[2:]
+ word_probs = torch.zeros((batch_size, num_steps, self.word_num)).to(device=self.device)
+ images_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
+
+ word_alpha_sum = torch.zeros((batch_size, 1, height, width)).to(device=self.device)
+ word_alphas = torch.zeros((batch_size, num_steps, height, width)).to(device=self.device)
+ hidden = self.init_hidden(cnn_features, images_mask)
+
+ cnn_features_trans = self.encoder_feature_conv(cnn_features)
+
+ word_context_vec_list, label_list, word_out_state_list = [], [], []
+ if is_train:
+ for i in range(num_steps):
+ word_embedding = self.embedding(labels[:, i-1]) if i else self.embedding(torch.ones([batch_size]).long().to(self.device))
+ hidden = self.word_input_gru(word_embedding, hidden)
+ word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden,
+ word_alpha_sum, images_mask)
+ hidden = self.word_output_gru(word_context_vec, hidden)
+
+ current_state = self.word_state_weight(hidden)
+ word_weighted_embedding = self.word_embedding_weight(word_embedding)
+ word_context_weighted = self.word_context_weight(word_context_vec)
+
+ if self.params['dropout']:
+ word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted)
+ else:
+ word_out_state = current_state + word_weighted_embedding + word_context_weighted
+
+ word_prob = self.word_convert(word_out_state)
+ word_probs[:, i] = word_prob
+ word_alphas[:, i] = word_alpha
+
+ word_context_vec_list.append(word_context_vec)
+ label_list.append(labels[:, i])
+ word_out_state_list.append(word_out_state)
+
+ else:
+ word_embedding = self.embedding(torch.ones([batch_size]).long().to(device=self.device))
+ for i in range(num_steps):
+ hidden = self.word_input_gru(word_embedding, hidden)
+ word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden,
+ word_alpha_sum, images_mask)
+ hidden = self.word_output_gru(word_context_vec, hidden)
+
+ current_state = self.word_state_weight(hidden)
+ word_weighted_embedding = self.word_embedding_weight(word_embedding)
+ word_context_weighted = self.word_context_weight(word_context_vec)
+
+ if self.params['dropout']:
+ word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted)
+ else:
+ word_out_state = current_state + word_weighted_embedding + word_context_weighted
+
+ word_prob = self.word_convert(word_out_state)
+ _, word = word_prob.max(1)
+ word_embedding = self.embedding(word)
+ word_probs[:, i] = word_prob
+ word_alphas[:, i] = word_alpha
+ return word_probs, word_alphas, (self.embedding.weight, word_context_vec_list, word_out_state_list, label_list)
+
+ def init_hidden(self, features, feature_mask):
+ average = (features * feature_mask).sum(-1).sum(-1) / feature_mask.sum(-1).sum(-1)
+ average = self.init_weight(average)
+ return torch.tanh(average)
+
+
diff --git a/models/decoder/decoder_v3.py b/models/decoder/decoder_v3.py
new file mode 100644
index 0000000..e6cee0f
--- /dev/null
+++ b/models/decoder/decoder_v3.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+from models.decoder.attention import Attention
+import models
+
+class Decoder_v1(nn.Module):
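+    """Decoder_v2 plus an optional counting branch whose prediction is projected
+    and added into the word output state (exported as Decoder_v3 in
+    models/__init__.py)."""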
+ def __init__(self, params):
+ super(Decoder_v1, self).__init__()
+ self.params = params
+ self.input_size = params['decoder']['input_size']
+ self.hidden_size = params['decoder']['hidden_size']
+ self.out_channel = params['encoder']['out_channel']
+ self.attention_dim = params['attention']['attention_dim']
+ self.dropout_prob = params['dropout']
+ self.device = params['device']
+ self.word_num = params['word_num']
+ # self.counting_num = params['counting_decoder']['out_channel']
+
+        """Ratio by which the CNN downsamples height/width relative to the input"""
+ self.ratio = params['densenet']['ratio']
+
+ # init hidden state
+ self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
+ # word embedding
+ self.embedding = nn.Embedding(self.word_num, self.input_size)
+ # word gru
+ self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
+ self.word_output_gru = nn.GRUCell(self.out_channel, self.hidden_size)
+ # attention
+ self.word_attention = Attention(params)
+ self.encoder_feature_conv = nn.Conv2d(self.out_channel, self.attention_dim,
+ kernel_size=params['attention']['word_conv_kernel'],
+ padding=params['attention']['word_conv_kernel']//2)
+
+ self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
+ self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size)
+ self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
+ # self.counting_context_weight = nn.Linear(self.counting_num, self.hidden_size)
+ self.word_convert = nn.Linear(self.hidden_size, self.word_num)
+
+ if params['dropout']:
+ self.dropout = nn.Dropout(params['dropout_ratio'])
+
+ if "counting_decoder" in self.params and self.params['counting_decoder']['use_flag']:
+ self.counting_context_weight = nn.Linear(self.params['counting_decoder']['out_channel'], self.params['decoder']['hidden_size'])
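+            # Two counting branches with different kernel sizes (3 and 5);
+            # their predictions are averaged in counting(), presumably to
+            # capture symbols at two scales.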
+ self.counting_decoder1 = getattr(models, "counting_decoder")(self.params['counting_decoder']['in_channel'], self.params['counting_decoder']['out_channel'], 3)
+ self.counting_decoder2 = getattr(models, "counting_decoder")(self.params['counting_decoder']['in_channel'], self.params['counting_decoder']['out_channel'], 5)
+
+ self.counting_loss = nn.SmoothL1Loss(reduction='mean')
+
+ def counting(self, cnn_features, images_mask, counting_labels):
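+        # Downsample the image mask by the CNN stride ratio so it matches the
+        # spatial size of cnn_features.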
+ counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
+
+ counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
+ counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
+ counting_preds = (counting_preds1 + counting_preds2) / 2
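+        # Supervise both branches and their average against the ground-truth
+        # counts.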
+ counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \
+ + self.counting_loss(counting_preds, counting_labels)
+
+ return counting_loss, counting_preds
+
+ def forward(self, cnn_features, labels, images_mask, label_mask, counting_labels=None, is_train=True):
+ counting_loss, counting_context_weighted = 0., 0.
+ if "counting_decoder" in self.params and self.params['counting_decoder']['use_flag']:
+ counting_loss, counting_preds = self.counting(cnn_features, images_mask, counting_labels)
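+            # Project the predicted symbol counts into the hidden space; this
+            # vector is added to the decoder output state at every timestep.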
+ counting_context_weighted = self.counting_context_weight(counting_preds)
+
+ batch_size, num_steps = labels.shape
+ height, width = cnn_features.shape[2:]
+ word_probs = torch.zeros((batch_size, num_steps, self.word_num)).to(device=self.device)
+ images_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
+
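+        # word_alpha_sum accumulates the attention maps over decoding steps
+        # (coverage), which lets the attention module discourage re-attending
+        # to already-covered regions.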
+ word_alpha_sum = torch.zeros((batch_size, 1, height, width)).to(device=self.device)
+ word_alphas = torch.zeros((batch_size, num_steps, height, width)).to(device=self.device)
+ hidden = self.init_hidden(cnn_features, images_mask)
+
+ cnn_features_trans = self.encoder_feature_conv(cnn_features)
+
+ word_context_vec_list, label_list, word_out_state_list = [], [], []
+ if is_train:
+ for i in range(num_steps):
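+                # Teacher forcing: use the ground-truth previous token as
+                # input; index 1 is presumably the start-of-sequence token.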
+ word_embedding = self.embedding(labels[:, i-1]) if i else self.embedding(torch.ones([batch_size]).long().to(self.device))
+ hidden = self.word_input_gru(word_embedding, hidden)
+ word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden,
+ word_alpha_sum, images_mask)
+ hidden = self.word_output_gru(word_context_vec, hidden)
+
+ current_state = self.word_state_weight(hidden)
+ word_weighted_embedding = self.word_embedding_weight(word_embedding)
+ word_context_weighted = self.word_context_weight(word_context_vec)
+
+ if self.params['dropout']:
+ word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted)
+ else:
+ word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
+
+ word_prob = self.word_convert(word_out_state)
+ word_probs[:, i] = word_prob
+ word_alphas[:, i] = word_alpha
+
+ word_context_vec_list.append(word_context_vec)
+ label_list.append(labels[:, i])
+ word_out_state_list.append(word_out_state)
+
+ else:
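+            # Greedy decoding: start from token index 1 and feed back the
+            # argmax prediction at each step.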
+ word_embedding = self.embedding(torch.ones([batch_size]).long().to(device=self.device))
+ for i in range(num_steps):
+ hidden = self.word_input_gru(word_embedding, hidden)
+ word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden,
+ word_alpha_sum, images_mask)
+ hidden = self.word_output_gru(word_context_vec, hidden)
+
+ current_state = self.word_state_weight(hidden)
+ word_weighted_embedding = self.word_embedding_weight(word_embedding)
+ word_context_weighted = self.word_context_weight(word_context_vec)
+
+ if self.params['dropout']:
+ word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted)
+ else:
+ word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
+
+ word_prob = self.word_convert(word_out_state)
+ _, word = word_prob.max(1)
+ word_embedding = self.embedding(word)
+ word_probs[:, i] = word_prob
+ word_alphas[:, i] = word_alpha
+ return word_probs, word_alphas, (self.embedding.weight, word_context_vec_list, word_out_state_list, label_list, counting_loss)
+
+ def init_hidden(self, features, feature_mask):
+ average = (features * feature_mask).sum(-1).sum(-1) / feature_mask.sum(-1).sum(-1)
+ average = self.init_weight(average)
+ return torch.tanh(average)
+
+
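+# A minimal usage sketch (assumptions: a params dict laid out as above and a
+# `models.counting_decoder` class taking (in_channel, out_channel, kernel_size);
+# the tensor names here are only illustrative):
+#
+#   decoder = Decoder_v1(params)
+#   word_probs, word_alphas, extras = decoder(
+#       cnn_features, labels, images_mask, label_mask,
+#       counting_labels=counting_labels, is_train=True)
+#   # extras == (embedding_weight, context_vecs, out_states, labels, counting_loss)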
diff --git a/models/infer_model.py b/models/infer_model.py
deleted file mode 100644
index e285080..0000000
--- a/models/infer_model.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import os
-import cv2
-import torch
-import torch.nn as nn
-import math
-import numpy as np
-
-
-from models.densenet import DenseNet
-from models.decoder.attention import Attention
-from utils import draw_attention_map, draw_counting_map
-
-
-class Inference(nn.Module):
- def __init__(self, params=None, draw_map=False):
- super(Inference, self).__init__()
- self.params = params
- self.draw_map = draw_map
- self.use_label_mask = params['use_label_mask']
- self.encoder = DenseNet(params=self.params)
- self.in_channel = params['counting_decoder']['in_channel']
- self.out_channel = params['counting_decoder']['out_channel']
- self.device = params['device']
- self.decoder = decoder_dict[params['decoder']['net']](params=self.params)
-
-        """Ratio by which the feature map's height and width shrink, relative to the original image, after the CNN."""
- self.ratio = params['densenet']['ratio']
-
- with open(params['word_path']) as f:
- words = f.readlines()
-            print(f'{len(words)} symbol classes in total.')
- self.words_index_dict = {i: words[i].strip() for i in range(len(words))}
- self.cal_mae = nn.L1Loss(reduction='mean')
- self.cal_mse = nn.MSELoss(reduction='mean')
- self.ignore_symbos = self.params['words'].encode(['','','{','}','_','^'])
-
- def forward(self, images, labels, name, is_train=False):
- cnn_features = self.encoder(images)
- batch_size, _, height, width = cnn_features.shape
- counting_preds1, counting_maps1 = self.counting_decoder1(cnn_features, None)
- counting_preds2, counting_maps2 = self.counting_decoder2(cnn_features, None)
- counting_preds = (counting_preds1 + counting_preds2) / 2
- counting_maps = (counting_maps1 + counting_maps2) / 2
-
- mae = self.cal_mae(counting_preds, gen_counting_label(labels, self.out_channel, True, ignore_symbos=self.ignore_symbos)).item()
- mse = math.sqrt(self.cal_mse(counting_preds, gen_counting_label(labels, self.out_channel, True, ignore_symbos=self.ignore_symbos)).item())
-
- word_probs, alphas = self.decoder(cnn_features, counting_preds, counting_map=counting_maps, is_train=is_train)
-
- if self.params['decoder']['net'] in ['AttFusionLocDecoder_v2','AttFusionLocDecoder_v3']:
- word_alphas, loc_alphas = alphas
- else:
- word_alphas = alphas
-
- if self.draw_map:
- if not os.path.exists(os.path.join(self.params['attention_map_vis_path'], name)):
- os.makedirs(os.path.join(self.params['attention_map_vis_path'], name), exist_ok=True)
- if not os.path.exists(os.path.join(self.params['counting_map_vis_path'], name)):
- os.makedirs(os.path.join(self.params['counting_map_vis_path'], name), exist_ok=True)
- for i in range(images.shape[0]):
- img = 255 - images[i][0].detach().cpu().numpy() * 255
- # draw attention_map and loc attention_map
- for step in range(len(word_probs)):
- word_atten = word_alphas[step][0].detach().cpu().numpy()
- word_heatmap = draw_attention_map(img, word_atten)
-
- if self.params['decoder']['net'] in ['AttFusionLocDecoder_v2', 'AttFusionLocDecoder_v3']:
- loc_atten = loc_alphas[step][0].detach().cpu().numpy()
- loc_heatmap = draw_attention_map(img, loc_atten)
-
- h, w, c = word_heatmap.shape
- edge = np.ones([10,w,c]) * 255
-
- out = np.concatenate([word_heatmap, edge, loc_heatmap], axis=0)
- else:
- out = word_heatmap
-
- cv2.imwrite(os.path.join(self.params['attention_map_vis_path'], name, f'{step}.jpg'), out)
-
- # draw counting_map
- for idx in range(self.out_channel):
- counting_map = counting_maps[0].permute(1,2,0)[:,:,idx].detach().cpu()
- counting_heatmap = draw_counting_map(img, counting_map)
- img_name = 'symbol_' + self.words_index_dict[idx] + '_map.jpg'
- cv2.imwrite(os.path.join(self.params['counting_map_vis_path'], name, img_name), counting_heatmap)
-
- return word_probs, word_alphas, mae, mse
-
-
-class AttDecoder(nn.Module):
- def __init__(self, params):
- super(AttDecoder, self).__init__()
- self.params = params
- self.input_size = params['decoder']['input_size']
- self.hidden_size = params['decoder']['hidden_size']
- self.out_channel = params['encoder']['out_channel']
- self.attention_dim = params['attention']['attention_dim']
- self.dropout_prob = params['dropout']
- self.device = params['device']
- self.word_num = params['word_num']
- self.ratio = params['densenet']['ratio']
-
- self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
- self.embedding = nn.Embedding(self.word_num, self.input_size)
- self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
- self.encoder_feature_conv = nn.Conv2d(self.out_channel, self.attention_dim, kernel_size=1)
- self.word_attention = Attention(params)
-
- self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
- self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size)
- self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
- self.counting_context_weight = nn.Linear(self.word_num, self.hidden_size)
- self.word_convert = nn.Linear(self.hidden_size, self.word_num)
-
- if params['dropout']:
- self.dropout = nn.Dropout(params['dropout_ratio'])
-
- def forward(self, cnn_features, counting_preds, is_train=False):
- batch_size, _, height, width = cnn_features.shape
- image_mask = torch.ones((batch_size, 1, height, width)).to(self.device)
-
- cnn_features_trans = self.encoder_feature_conv(cnn_features)
- position_embedding = PositionEmbeddingSine(256, normalize=True)
- pos = position_embedding(cnn_features_trans, image_mask[:,0,:,:])
- cnn_features_trans = cnn_features_trans + pos
-
- word_alpha_sum = torch.zeros((batch_size, 1, height, width)).to(device=self.device)
- hidden = self.init_hidden(cnn_features, image_mask)
- word_embedding = self.embedding(torch.ones([batch_size]).long().to(device=self.device))
- counting_context_weighted = self.counting_context_weight(counting_preds)
- word_probs = []
- word_alphas = []
-
- i = 0
- while i < 200:
- hidden = self.word_input_gru(word_embedding, hidden)
- word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden,
- word_alpha_sum, image_mask)
-
- current_state = self.word_state_weight(hidden)
- word_weighted_embedding = self.word_embedding_weight(word_embedding)
- word_context_weighted = self.word_context_weight(word_context_vec)
-
- if self.params['dropout']:
- word_out_state = self.dropout(current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted)
- else:
- word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
-
- word_prob = self.word_convert(word_out_state)
- _, word = word_prob.max(1)
- word_embedding = self.embedding(word)
- if word.item() == 0:
- return word_probs, word_alphas
- word_alphas.append(word_alpha)
- word_probs.append(word)
- i+=1
- return word_probs, word_alphas
-
- def init_hidden(self, features, feature_mask):
- average = (features * feature_mask).sum(-1).sum(-1) / feature_mask.sum(-1).sum(-1)
- average = self.init_weight(average)
- return torch.tanh(average)
-
-
-decoder_dict = {
- 'AttDecoder': AttDecoder,
- # 'AttFusionLocDecoder_v1': AttFusionLocDecoder_v1,
- # 'AttFusionLocDecoder_v2': AttFusionLocDecoder_v2,
- # 'AttFusionLocDecoder_v3': AttFusionLocDecoder_v3,
- # 'AttFusionLocDecoder_v4': AttFusionLocDecoder_v4
-}
\ No newline at end of file
diff --git a/train.py b/train.py
index c88f8aa..d1ebac0 100644
--- a/train.py
+++ b/train.py
@@ -55,7 +55,8 @@
if params['finetune']:
    print('Loading pretrained model weights')
    print(f'Pretrained weight path: {params["checkpoint"]}')
- load_checkpoint(model, optimizer, params['checkpoint'])
+ if params["checkpoint"]:
+ load_checkpoint(model, optimizer, params['checkpoint'])
if not args.check:
if not os.path.exists(os.path.join(params['checkpoint_dir'], model.name)):
@@ -67,12 +68,12 @@
min_score = 0
min_step = 0
rate_2014, rate_2016, rate_2019 = 0.55, 0.54, 0.55
-rate_2014, rate_2016, rate_2019 = 0.50, 0.50, 0.50
+rate_2014, rate_2016, rate_2019 = 0.54, 0.54, 0.54
# init_epoch = 0 if not params['finetune'] else int(params['checkpoint'].split('_')[-1].split('.')[0])
if args.val:
epoch = 1
- model.load_state_dict(torch.load("checkpoints/v1_l1-loss_2022-11-28-23-38_decoder-Decoder_v1/2016_v1_l1-loss_2022-11-28-23-38_decoder-Decoder_v1_WordRate-0.9057_ExpRate-0.5405_182.pth", map_location="cpu")['model'])
+ # model.load_state_dict(torch.load("checkpoints/v1_l2-loss_2022-11-30-09-57_decoder-Decoder_v1/2016_v1_l2-loss_2022-11-30-09-57_decoder-Decoder_v1_WordRate-0.9094_ExpRate-0.5562_190.pth", map_location="cpu")['model'])
print()
eval_loss, eval_word_score, eval_expRate = eval(params, model, epoch, eval_loader_14)
print(f'2014 Epoch: {epoch + 1} loss: {eval_loss:.4f} word score: {eval_word_score:.4f} ExpRate: {eval_expRate:.4f}')
diff --git a/training.py b/training.py
index e97c234..45a3d64 100644
--- a/training.py
+++ b/training.py
@@ -1,26 +1,45 @@
import torch
from tqdm import tqdm
from utils import update_lr, Meter, cal_score
-
-
-def train(params, model, optimizer, epoch, train_loader, writer=None):
+from torch import nn
+from copy import deepcopy
+
+# Snapshot of the frozen parameters, used to verify across epochs that they
+# never change while only the selected part is fine-tuned.
+shadow_model = {}
+
+def finetune_part(model: nn.Module, name):
+    """Unfreeze parameters whose name contains `name`; freeze all others and
+    assert that the frozen ones stay unchanged between calls."""
+    for k, v in model.named_parameters():
+        if name in k:
+            v.requires_grad = True
+        else:
+            v.requires_grad = False
+            # Only frozen parameters are shadowed and checked; the trainable
+            # ones are expected to change between epochs.
+            if k in shadow_model:
+                assert torch.equal(shadow_model[k], v), "frozen param changed!"
+            else:
+                shadow_model[k] = deepcopy(v)
+    return
+
+def train(params, model, optimizer: torch.optim.Optimizer, epoch, train_loader, writer=None):
model.train()
device = params['device']
loss_meter = Meter()
word_right, exp_right, length, cal_num = 0, 0, 0, 0
+ if params['finetune']:
+ finetune_part(model, 'counting')
+
with tqdm(train_loader, total=len(train_loader)//params['train_parts']) as pbar:
- for batch_idx, (images, image_masks, labels, label_masks) in enumerate(pbar):
- images, image_masks, labels, label_masks = images.to(device), image_masks.to(
- device), labels.to(device), label_masks.to(device)
+ for batch_idx, (images, image_masks, labels, label_masks, matrix, counting_labels) in enumerate(pbar):
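+            # The loader now also yields `matrix` (presumably a symbol-relation
+            # matrix for the similarity losses) and per-class `counting_labels`.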
+ images, image_masks, labels, label_masks, matrix, counting_labels = \
+ images.to(device, non_blocking=True), image_masks.to(device, non_blocking=True), \
+ labels.to(device, non_blocking=True), label_masks.to(device, non_blocking=True), \
+ matrix.to(device, non_blocking=True), counting_labels.to(device, non_blocking=True)
batch, time = labels.shape[:2]
if not 'lr_decay' in params or params['lr_decay'] == 'cosine':
update_lr(optimizer, epoch, batch_idx, len(train_loader), params['epochs'], params['lr'])
optimizer.zero_grad()
- probs, loss = model(images, image_masks, labels, label_masks)
- word_loss, sim_loss = loss
- loss = word_loss + sim_loss
-
+ probs, loss = model(images, image_masks, labels, label_masks, matrix=matrix, counting_labels=counting_labels)
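+            # The model now returns five loss terms: the word loss plus the
+            # sim/context/word_state auxiliary terms and the counting loss.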
+ word_loss, sim_loss, context_loss, word_state_loss, counting_loss = loss
+ loss = word_loss + sim_loss + context_loss + word_state_loss + counting_loss
loss.backward()
if params['gradient_clip']:
@@ -33,17 +52,24 @@ def train(params, model, optimizer, epoch, train_loader, writer=None):
exp_right = exp_right + ExpRate * batch
length = length + time
cal_num = cal_num + batch
-
+
+ if isinstance(sim_loss, torch.Tensor):
+ sim_loss = sim_loss.item()
+
if writer:
current_step = epoch * len(train_loader) // params['train_parts'] + batch_idx + 1
writer.add_scalar('train/loss', loss.item(), current_step)
+ writer.add_scalar('train/sim', sim_loss, current_step)
+ writer.add_scalar('train/counting', counting_loss, current_step)
+ writer.add_scalar('train/context', context_loss, current_step)
+ writer.add_scalar('train/word', word_state_loss, current_step)
writer.add_scalar('train/WordRate', wordRate, current_step)
writer.add_scalar('train/ExpRate', ExpRate, current_step)
writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], current_step)
-
- string = f'{epoch + 1} word_loss:{word_loss.item():.3f} sim_loss: {sim_loss.item():.3f} '
+
+ string = f'{epoch + 1} word_loss:{word_loss.item():.3f} sim_loss: {sim_loss:.3f} '
string += f'WRate:{word_right / length:.3f} ERate:{exp_right / cal_num:.3f}'
-
+
pbar.set_description(string)
if batch_idx >= len(train_loader) // params['train_parts']:
break
@@ -62,13 +88,13 @@ def eval(params, model, epoch, eval_loader, writer=None):
word_right, exp_right, length, cal_num = 0, 0, 0, 0
with tqdm(eval_loader, total=len(eval_loader)//params['valid_parts']) as pbar, torch.no_grad():
- for batch_idx, (images, image_masks, labels, label_masks) in enumerate(pbar):
- images, image_masks, labels, label_masks = images.to(device), image_masks.to(
- device), labels.to(device), label_masks.to(device)
+ for batch_idx, (images, image_masks, labels, label_masks, _, counting_labels) in enumerate(pbar):
+ images, image_masks, labels, label_masks, counting_labels = images.to(device), image_masks.to(
+ device), labels.to(device), label_masks.to(device), counting_labels.to(device)
batch, time = labels.shape[:2]
- probs, loss = model(images, image_masks, labels, label_masks, is_train=False)
+ probs, loss = model(images, image_masks, labels, label_masks, counting_labels=counting_labels, is_train=False)
- word_loss, sim_loss = loss
+ word_loss, sim_loss, _, _, counting_loss = loss
loss = word_loss + sim_loss
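+            # Note: the reported eval loss sums only the word and sim terms;
+            # the other auxiliary terms are not included here.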
loss_meter.add(loss.item())
@@ -78,6 +104,9 @@ def eval(params, model, epoch, eval_loader, writer=None):
length = length + time
cal_num = cal_num + batch
+ if isinstance(sim_loss, torch.Tensor):
+ sim_loss = sim_loss.item()
+
if writer:
current_step = epoch * len(eval_loader)//params['valid_parts'] + batch_idx + 1
writer.add_scalar('eval/word_loss', word_loss.item(), current_step)
@@ -85,7 +114,7 @@ def eval(params, model, epoch, eval_loader, writer=None):
writer.add_scalar('eval/WordRate', wordRate, current_step)
writer.add_scalar('eval/ExpRate', ExpRate, current_step)
- pbar.set_description(f'{epoch+1} word_loss:{word_loss.item():.4f} sim_loss:{sim_loss.item():.4f}'
+ pbar.set_description(f'{epoch+1} word_loss:{word_loss.item():.4f} sim_loss:{sim_loss:.4f}'
f' WRate:{word_right / length:.4f} ERate:{exp_right / cal_num:.4f}')
if batch_idx >= len(eval_loader) // params['valid_parts']:
break