diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..c87dcaa53faf5f8e68ea424ea18ae68fad94a2a9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/fig_teaser.png filter=lfs diff=lfs merge=lfs -text +triton-2.0.0-cp310-cp310-win_amd64.whl filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..061fb1affd69161b70eb17c92f8172eb60cf92d6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,192 @@ +# Initially taken from Github's Python gitignore file + +ckpts +sam_pt +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# tests and logs +tests/fixtures/cached_*_text.txt +logs/ +lightning_logs/ +lang_code_data/ + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# vscode +.vs +.vscode + +# Pycharm +.idea + +# TF code +tensorflow_code + +# Models +proc_data + +# examples +runs +/runs_old +/wandb +/examples/runs +/examples/**/*.args +/examples/rag/sweep + +# data +/data +serialization_dir + +# emacs +*.*~ +debug.env + +# vim +.*.swp + +#ctags +tags + +# pre-commit +.pre-commit* + +# .lock +*.lock + +# DS_Store (MacOS) +.DS_Store +# RL pipelines may produce mp4 outputs +*.mp4 + +# dependencies +/transformers + +# ruff +.ruff_cache + +# ckpts +*.ckpt + +outputs/* + +NeuS/exp/* +NeuS/test_scenes/* +NeuS/mesh2tex/* +neus_configs +vast/* +render_results +experiments/* +ckpts/* +neus/* +instant-nsr-pl/exp/* \ No newline at end of file diff --git a/1gpu.yaml b/1gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac5e9f453e944be76e55ce2a3992788f054f5cff --- /dev/null +++ b/1gpu.yaml @@ -0,0 +1,15 @@ +compute_environment: LOCAL_MACHINE +distributed_type: 'NO' +downcast_bf16: 'no' +gpu_ids: '0' +machine_rank: 0 +main_training_function: main +mixed_precision: 'no' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/LICENSE b/LICENSE new file mode 100644 index 
0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52 --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. 
+ + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. 
+ + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. 
This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. 
Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. 
+Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). 
+ + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". 
+ + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. 
+ + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. 
EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. 
diff --git a/README.md b/README.md
index 3d58fe5179c6286fa413653385bf8b7cb888b366..b39d2664101169fb8ee35702c5d80983b2c1d364 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,211 @@
 ---
 title: Wonder3D
-emoji: 🏢
-colorFrom: blue
-colorTo: pink
+app_file: gradio_app_mv.py
 sdk: gradio
-sdk_version: 4.15.0
-app_file: app.py
-pinned: false
+sdk_version: 3.50.2
 ---
+**Chinese version: [中文](README_zh.md)**
+# Wonder3D
+Single Image to 3D using Cross-Domain Diffusion
+## [Paper](https://arxiv.org/abs/2310.15008) | [Project page](https://www.xxlong.site/Wonder3D/) | [Hugging Face Demo](https://huggingface.co/spaces/flamehaze1115/Wonder3D-demo) | [Colab from @camenduru](https://github.com/camenduru/Wonder3D-colab)
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+![](assets/fig_teaser.png)
+
+Wonder3D reconstructs highly detailed textured meshes from a single-view image in only 2 ∼ 3 minutes. Wonder3D first generates consistent multi-view normal maps with corresponding color images via a cross-domain diffusion model, and then leverages a novel normal fusion method to achieve fast and high-quality reconstruction.
+
+## Usage
+```python
+
+# First clone the repo, and use the commands in the repo
+
+import torch
+import requests
+from PIL import Image
+import numpy as np
+from torchvision.utils import make_grid, save_image
+from diffusers import DiffusionPipeline  # only tested on diffusers[torch]==0.19.3, may have conflicts with newer versions of diffusers
+
+def load_wonder3d_pipeline():
+    pipeline = DiffusionPipeline.from_pretrained(
+        'flamehaze1115/wonder3d-v1.0',  # or use local checkpoint './ckpts'
+        custom_pipeline='flamehaze1115/wonder3d-pipeline',
+        torch_dtype=torch.float16
+    )
+
+    # enable xformers
+    pipeline.unet.enable_xformers_memory_efficient_attention()
+
+    if torch.cuda.is_available():
+        pipeline.to('cuda:0')
+    return pipeline
+
+pipeline = load_wonder3d_pipeline()
+
+# Download an example image.
+cond = Image.open(requests.get("https://d.skis.ltd/nrp/sample-data/lysol.png", stream=True).raw)
+
+# The object should be located in the center and resized to 80% of image height.
+cond = Image.fromarray(np.array(cond)[:, :, :3])
+
+# Run the pipeline!
+images = pipeline(cond, num_inference_steps=20, output_type='pt', guidance_scale=1.0).images
+
+result = make_grid(images, nrow=6, padding=0, value_range=(0, 1))
+
+save_image(result, 'result.png')
+```
+
+## Collaborations
+Our overarching mission is to enhance the speed, affordability, and quality of 3D AIGC, making the creation of 3D content accessible to all. While significant progress has been achieved in recent years, we acknowledge there is still a substantial journey ahead. We enthusiastically invite you to engage in discussions and explore potential collaborations in any capacity. **If you're interested in connecting or partnering with us, please don't hesitate to reach out via email (xxlong@connect.hku.hk).**
+
+## More features
+
+The repo is still under construction, thanks for your patience.
+- [x] Local gradio demo.
+- [x] Detailed tutorial.
+- [x] GUI demo for mesh reconstruction.
+- [x] Windows support.
+- [x] Docker support.
+
+## Schedule
+- [x] Inference code and pretrained models.
+- [x] Hugging Face demo.
+- [ ] New model with higher resolution.
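+
+As noted in the Usage snippet above, the input object should be centered and occupy roughly 80% of the image height. Below is a minimal preprocessing sketch, not part of the official pipeline: it assumes the input already has an alpha channel (e.g. from `rembg` or Clipdrop), and the 256x256 canvas size and the `example_images/owl.png` path are only placeholders.
+
+```python
+# Hypothetical preprocessing helper: crop to the alpha bounding box, then paste the
+# object centered on a square canvas so that it fills ~80% of the image height.
+import numpy as np
+from PIL import Image
+
+def center_and_resize(path, out_size=256, fill_ratio=0.8):
+    img = Image.open(path).convert('RGBA')
+    alpha = np.array(img)[:, :, 3]
+    ys, xs = np.nonzero(alpha > 0)  # foreground pixels from the alpha mask
+    crop = img.crop((xs.min(), ys.min(), xs.max() + 1, ys.max() + 1))
+    scale = fill_ratio * out_size / crop.height
+    crop = crop.resize((max(1, round(crop.width * scale)), max(1, round(crop.height * scale))))
+    canvas = Image.new('RGBA', (out_size, out_size), (255, 255, 255, 0))
+    canvas.paste(crop, ((out_size - crop.width) // 2, (out_size - crop.height) // 2), crop)
+    return canvas
+
+center_and_resize('example_images/owl.png').save('owl_centered.png')  # placeholder paths
+```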
+
+
+### Preparation for inference
+
+#### Linux System Setup
+```bash
+conda create -n wonder3d
+conda activate wonder3d
+pip install -r requirements.txt
+pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
+```
+#### Windows System Setup
+
+Please switch to the `main-windows` branch for details of the Windows setup.
+
+#### Docker Setup
+See [docker/README.md](docker/README.md).
+
+### Inference
+1. Optional. If you have trouble connecting to Hugging Face, make sure you have downloaded the following models.
+Download the [checkpoints](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/xxlong_connect_hku_hk/Ej7fMT1PwXtKvsELTvDuzuMBebQXEkmf2IwhSjBWtKAJiA) and put them into the root folder.
+
+If you are in mainland China, you may download via [aliyun](https://www.alipan.com/s/T4rLUNAVq6V).
+
+```bash
+Wonder3D
+|-- ckpts
+    |-- unet
+    |-- scheduler
+    |-- vae
+    ...
+```
+Then modify the file `./configs/mvdiffusion-joint-ortho-6views.yaml` and set `pretrained_model_name_or_path="./ckpts"`.
+
+2. Download the [SAM](https://huggingface.co/spaces/abhishek/StableSAM/blob/main/sam_vit_h_4b8939.pth) model. Put it into the `sam_pt` folder.
+```
+Wonder3D
+|-- sam_pt
+    |-- sam_vit_h_4b8939.pth
+```
+3. Predict the foreground mask as the alpha channel. We use [Clipdrop](https://clipdrop.co/remove-background) to segment the foreground object interactively.
+You may also use `rembg` to remove the backgrounds.
+```python
+# pip install rembg
+import rembg
+from PIL import Image
+
+image = Image.open('example_images/owl.png')  # your input image
+result = rembg.remove(image)
+result.show()
+```
+4. Run Wonder3D to produce multi-view consistent normal maps and color images. Then you can check the results in the folder `./outputs`. (We use `rembg` to remove the backgrounds of the results, but the segmentations are not always perfect. Consider using [Clipdrop](https://clipdrop.co/remove-background) to get masks for the generated normal maps and color images, since the quality of the masks will significantly influence the quality of the reconstructed mesh.)
+```bash
+accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \
+            --config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir={your_data_path} \
+            validation_dataset.filepaths=['your_img_file'] save_dir={your_save_path}
+```
+
+See an example:
+
+```bash
+accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \
+            --config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir=./example_images \
+            validation_dataset.filepaths=['owl.png'] save_dir=./outputs
+```
+
+#### Interactive inference: run your local gradio demo (only generates normals and colors, without reconstruction)
+```bash
+python gradio_app_mv.py # generate multi-view normals and colors
+```
+
+5. Mesh Extraction
+
+#### Instant-NSR Mesh Extraction
+
+```bash
+cd ./instant-nsr-pl
+python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../{your_save_path}/cropsize-{crop_size}-cfg{guidance_scale:.1f}/ dataset.scene={scene}
+```
+
+See an example:
+
+```bash
+cd ./instant-nsr-pl
+python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../outputs/cropsize-192-cfg1.0/ dataset.scene=owl
+```
+
+Our generated normals and color images are defined in orthographic views, so the reconstructed mesh is also in orthographic camera space. If you use MeshLab to view the meshes, you can click `Toggle Orthographic Camera` in the `View` tab.
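+
+If you prefer to inspect or convert the extracted mesh in code rather than in MeshLab, a minimal sketch using `trimesh` (pinned in `docker/requirements.txt`) is shown below. This is only an illustration under assumptions: the exact output location depends on your run, so the `.obj` path here is a placeholder.
+
+```python
+# Minimal sketch: load a mesh produced by instant-nsr-pl and export it to glTF.
+# The path below is a placeholder; check the experiment folder created by launch.py.
+import trimesh
+
+mesh = trimesh.load('instant-nsr-pl/exp/owl/example-run/save/last.obj', force='mesh')  # placeholder path
+print('vertices:', mesh.vertices.shape, 'faces:', mesh.faces.shape)
+mesh.export('owl_mesh.glb')  # glTF export; keeps vertex colors if present
+```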
+
+#### Interactive inference: run your local gradio demo (first generates normals and colors, then performs reconstruction; no need to run gradio_app_mv.py first)
+```bash
+python gradio_app_recon.py
+```
+
+#### NeuS-based Mesh Extraction
+
+Since there have been many complaints about the Windows setup of instant-nsr-pl, we also provide a NeuS-based reconstruction, which avoids those dependency problems.
+
+NeuS consumes less GPU memory and favors smooth surfaces without parameter tuning. However, NeuS takes more time and its textures may be less sharp. If you are not sensitive to runtime, we recommend NeuS for optimization due to its robustness.
+
+```bash
+cd ./NeuS
+bash run.sh output_folder_path scene_name
+```
+
+## Common questions
+Q: Any tips to get better results?
+1. Wonder3D is sensitive to the facing direction of the input images. In our experiments, front-facing images always lead to good reconstructions.
+2. Limited by resources, the current implementation only supports a limited number of views (6) and low resolution (256x256). Any input image is first resized to 256x256 for generation, so images that still keep clear and sharp features after such downsampling will lead to good results.
+3. Images with occlusions will cause worse reconstructions, since the 6 views cannot cover the complete object. Images with fewer occlusions lead to better results.
+4. Increase the optimization steps in instant-nsr-pl: modify `trainer.max_steps: 3000` in `instant-nsr-pl/configs/neuralangelo-ortho-wmask.yaml` to more steps, e.g. `trainer.max_steps: 10000`. Longer optimization leads to better textures.
+
+Q: What are the elevation and azimuth angles of the generated views?
+
+A: Unlike prior works such as Zero123, SyncDreamer and One2345, which adopt an object-centric world coordinate system, our views are defined in the camera system of the input image. The six views lie in the plane with 0 degrees of elevation in that camera system, so we don't need to estimate an elevation angle for the input image. The azimuth angles of the six views are 0, 45, 90, 180, -90 and -45 degrees, respectively.
+
+Q: What is the focal length of the generated views?
+
+A: We assume the input images are captured by an orthographic camera, so the generated views are also in orthographic space. This design enables our model to keep strong generalization on unreal images, but it may sometimes suffer from focal lens distortions on real captured images.
+## Acknowledgement
+We have borrowed code extensively from the following repositories. Many thanks to the authors for sharing their code.
+- [stable diffusion](https://github.com/CompVis/stable-diffusion)
+- [zero123](https://github.com/cvlab-columbia/zero123)
+- [NeuS](https://github.com/Totoro97/NeuS)
+- [SyncDreamer](https://github.com/liuyuan-pal/SyncDreamer)
+- [instant-nsr-pl](https://github.com/bennyguo/instant-nsr-pl)
+
+## License
+Wonder3D is under [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html), so any downstream solutions and products (including cloud services) that include Wonder3D code or a trained model (whether pretrained or custom trained) should be open-sourced to comply with the AGPL conditions. If you have any questions about the usage of Wonder3D, please contact us first.
+
+## Citation
+If you find this repository useful in your project, please cite the following work.
:) +``` +@article{long2023wonder3d, + title={Wonder3D: Single Image to 3D using Cross-Domain Diffusion}, + author={Long, Xiaoxiao and Guo, Yuan-Chen and Lin, Cheng and Liu, Yuan and Dou, Zhiyang and Liu, Lingjie and Ma, Yuexin and Zhang, Song-Hai and Habermann, Marc and Theobalt, Christian and others}, + journal={arXiv preprint arXiv:2310.15008}, + year={2023} +} +``` diff --git a/README_zh.md b/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..d7f3e24d909bedc3c8c21652fb64e1debbc02d5a --- /dev/null +++ b/README_zh.md @@ -0,0 +1,203 @@ +**其他语言版本 [English](README.md)** + +# Wonder3D +Single Image to 3D using Cross-Domain Diffusion +## [Paper](https://arxiv.org/abs/2310.15008) | [Project page](https://www.xxlong.site/Wonder3D/) | [Hugging Face Demo](https://huggingface.co/spaces/flamehaze1115/Wonder3D-demo) | [Colab from @camenduru](https://github.com/camenduru/Wonder3D-colab) + +![](assets/fig_teaser.png) + +Wonder3D仅需2至3分钟即可从单视图图像中重建出高度详细的纹理网格。Wonder3D首先通过跨域扩散模型生成一致的多视图法线图与相应的彩色图像,然后利用一种新颖的法线融合方法实现快速且高质量的重建。 + +## Usage 使用 +```bash + +import torch +import requests +from PIL import Image +import numpy as np +from torchvision.utils import make_grid, save_image +from diffusers import DiffusionPipeline # only tested on diffusers[torch]==0.19.3, may have conflicts with newer versions of diffusers + +def load_wonder3d_pipeline(): + + pipeline = DiffusionPipeline.from_pretrained( + 'flamehaze1115/wonder3d-v1.0', # or use local checkpoint './ckpts' + custom_pipeline='flamehaze1115/wonder3d-pipeline', + torch_dtype=torch.float16 + ) + + # enable xformers + pipeline.unet.enable_xformers_memory_efficient_attention() + + if torch.cuda.is_available(): + pipeline.to('cuda:0') + return pipeline + +pipeline = load_wonder3d_pipeline() + +# Download an example image. +cond = Image.open(requests.get("https://d.skis.ltd/nrp/sample-data/lysol.png", stream=True).raw) + +# The object should be located in the center and resized to 80% of image height. +cond = Image.fromarray(np.array(cond)[:, :, :3]) + +# Run the pipeline! +images = pipeline(cond, num_inference_steps=20, output_type='pt', guidance_scale=1.0).images + +result = make_grid(images, nrow=6, ncol=2, padding=0, value_range=(0, 1)) + +save_image(result, 'result.png') +``` + +## Collaborations 合作 +我们的总体使命是提高3D人工智能图形生成(3D AIGC)的速度、可负担性和质量,使所有人都能够轻松创建3D内容。尽管近年来取得了显著的进展,我们承认前方仍有很长的路要走。我们热切邀请您参与讨论并在任何方面探索潜在的合作机会。**如果您有兴趣与我们联系或合作,请随时通过电子邮件(xxlong@connect.hku.hk)联系我们**。 + +## More features + +The repo is still being under construction, thanks for your patience. +- [x] Local gradio demo. +- [x] Detailed tutorial. +- [x] GUI demo for mesh reconstruction +- [x] Windows support +- [x] Docker support + +## Schedule +- [x] Inference code and pretrained models. +- [x] Huggingface demo. +- [ ] New model with higher resolution. + + +### Preparation for inference 测试准备 + +#### Linux System Setup. +```angular2html +conda create -n wonder3d +conda activate wonder3d +pip install -r requirements.txt +pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch +``` +#### Windows System Setup. + +请切换到`main-windows`分支以查看Windows设置的详细信息。 + +#### Docker Setup +详见 [docker/README.MD](docker/README.md) + +### Inference +1. 
可选。如果您在连接到Hugging Face时遇到问题,请确保已下载以下模型。 +下载[checkpoints](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/xxlong_connect_hku_hk/Ej7fMT1PwXtKvsELTvDuzuMBebQXEkmf2IwhSjBWtKAJiA)并放入根文件夹中。 + +国内用户可下载: [阿里云盘](https://www.alipan.com/s/T4rLUNAVq6V) + +```bash +Wonder3D +|-- ckpts + |-- unet + |-- scheduler + |-- vae + ... +``` +然后更改文件 ./configs/mvdiffusion-joint-ortho-6views.yaml, 设置 `pretrained_model_name_or_path="./ckpts"` + +2. 下载模型 [SAM](https://huggingface.co/spaces/abhishek/StableSAM/blob/main/sam_vit_h_4b8939.pth) . 放置在 ``sam_pt`` 文件夹. +``` +Wonder3D +|-- sam_pt + |-- sam_vit_h_4b8939.pth +``` +3. 预测前景蒙版作为阿尔法通道。我们使用[Clipdrop](https://clipdrop.co/remove-background)来交互地分割前景对象。 +您还可以使用`rembg`来去除背景。 +```bash +# !pip install rembg +import rembg +result = rembg.remove(result) +result.show() +``` +4. 运行Wonder3D以生成多视角一致的法线图和彩色图像。然后,您可以在文件夹`./outputs`中检查结果(我们使用`rembg`去除结果的背景,但分割并不总是完美的。可以考虑使用[Clipdrop](https://clipdrop.co/remove-background)获取生成的法线图和彩色图像的蒙版,因为蒙版的质量将显著影响重建的网格质量)。 +```bash +accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \ + --config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir={your_data_path} \ + validation_dataset.filepaths=['your_img_file'] save_dir={your_save_path} +``` + +示例: + +```bash +accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \ + --config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir=./example_images \ + validation_dataset.filepaths=['owl.png'] save_dir=./outputs +``` + +#### 运行本地的Gradio演示。仅生成法线和颜色,无需进行重建。 +```bash +python gradio_app_mv.py # generate multi-view normals and colors +``` + +5. Mesh Extraction + +#### Instant-NSR Mesh Extraction + +```bash +cd ./instant-nsr-pl +python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../{your_save_path}/cropsize-{crop_size}-cfg{guidance_scale:.1f}/ dataset.scene={scene} +``` + +示例: + +```bash +cd ./instant-nsr-pl +python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../outputs/cropsize-192-cfg1.0/ dataset.scene=owl +``` + +我们生成的法线图和彩色图像是在正交视图中定义的,因此重建的网格也处于正交摄像机空间。如果您使用MeshLab查看网格,可以在“View”选项卡中单击“Toggle Orthographic Camera”切换到正交相机。 + +#### 运行本地的Gradio演示。首先生成法线和颜色,然后进行重建。无需首先执行`gradio_app_mv.py`。 +```bash +python gradio_app_recon.py +``` + +#### NeuS-based Mesh Extraction + +由于许多用户对于instant-nsr-pl的Windows设置提出了抱怨,我们提供了基于NeuS的重建,这可能消除了一些要求方面的问题。 + +NeuS消耗较少的GPU内存,对平滑表面有利,无需参数调整。然而,NeuS需要更多时间,其纹理可能不够清晰。如果您对时间不太敏感,我们建议由于其稳健性而使用NeuS进行优化。 + +```bash +cd ./NeuS +bash run.sh output_folder_path scene_name +``` + +## 常见问题 +**获取更好结果的提示:** +1. **图片朝向方向敏感:** Wonder3D对输入图像的面向方向敏感。通过实验证明,面向前方的图像通常会导致良好的重建结果。 +2. **图像分辨率:** 受资源限制,当前实现仅支持有限的视图(6个视图)和低分辨率(256x256)。任何图像都将首先调整大小为256x256进行生成,因此在这样的降采样后仍然保持清晰而锐利特征的图像将导致良好的结果。 +3. **处理遮挡:** 具有遮挡的图像会导致更差的重建,因为6个视图无法完全覆盖整个对象。具有较少遮挡的图像通常会产生更好的结果。 +4. **增加instant-nsr-pl中的优化步骤:** 在instant-nsr-pl中增加优化步骤。在`instant-nsr-pl/configs/neuralangelo-ortho-wmask.yaml`中修改`trainer.max_steps: 3000`为更多步骤,例如`trainer.max_steps: 10000`。更长的优化步骤会导致更好的纹理。 + +**生成视图信息:** +- **仰角和方位角度:** 与Zero123、SyncDreamer和One2345等先前作品采用对象世界系统不同,我们的视图是在输入图像的相机系统中定义的。六个视图在输入图像的相机系统中的平面上,仰角为0度。因此,我们不需要为输入图像估算仰角。六个视图的方位角度分别为0、45、90、180、-90、-45。 + +**生成视图的焦距:** +- 我们假设输入图像是由正交相机捕获的,因此生成的视图也在正交空间中。这种设计使得我们的模型能够在虚构图像上保持强大的泛化能力,但有时可能在实际捕获的图像上受到焦距镜头畸变的影响。 + +## 致谢 +We have intensively borrow codes from the following repositories. Many thanks to the authors for sharing their codes. 
+- [stable diffusion](https://github.com/CompVis/stable-diffusion) +- [zero123](https://github.com/cvlab-columbia/zero123) +- [NeuS](https://github.com/Totoro97/NeuS) +- [SyncDreamer](https://github.com/liuyuan-pal/SyncDreamer) +- [instant-nsr-pl](https://github.com/bennyguo/instant-nsr-pl) + +## 协议 +Wonder3D采用[AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html)许可,因此任何包含Wonder3D代码或其中训练的模型(无论是预训练还是定制训练)的下游解决方案和产品(包括云服务)都应该开源以符合AGPL条件。如果您对Wonder3D的使用有任何疑问,请首先与我们联系。 + +## 引用 +如果您在项目中发现这个项目对您有用,请引用以下工作。 :) +``` +@article{long2023wonder3d, + title={Wonder3D: Single Image to 3D using Cross-Domain Diffusion}, + author={Long, Xiaoxiao and Guo, Yuan-Chen and Lin, Cheng and Liu, Yuan and Dou, Zhiyang and Liu, Lingjie and Ma, Yuexin and Zhang, Song-Hai and Habermann, Marc and Theobalt, Christian and others}, + journal={arXiv preprint arXiv:2310.15008}, + year={2023} +} +``` diff --git a/assets/fig_teaser.png b/assets/fig_teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..40cc76812117e7ae142042eb4f03769ea2911d78 --- /dev/null +++ b/assets/fig_teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e366d3fe06124b2f36ee43aca4da522a42e6ebf7d776cc0f5e8d0974cdc2971b +size 1271319 diff --git a/configs/mvdiffusion-joint-ortho-6views.yaml b/configs/mvdiffusion-joint-ortho-6views.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91ae6ac57d45fb410158102341a4d6d9ec5f2976 --- /dev/null +++ b/configs/mvdiffusion-joint-ortho-6views.yaml @@ -0,0 +1,42 @@ +pretrained_model_name_or_path: 'flamehaze1115/wonder3d-v1.0' # or './ckpts' +revision: null +validation_dataset: + root_dir: "./example_images" # the folder path stores testing images + num_views: 6 + bg_color: 'white' + img_wh: [256, 256] + num_validation_samples: 1000 + crop_size: 192 + filepaths: ['owl.png'] # the test image names. 
leave it empty, test all images in the folder + +save_dir: 'outputs/' + +pred_type: 'joint' +seed: 42 +validation_batch_size: 1 +dataloader_num_workers: 64 + +local_rank: -1 + +pipe_kwargs: + camera_embedding_type: 'e_de_da_sincos' + num_views: 6 + +validation_guidance_scales: [1.0] +pipe_validation_kwargs: + eta: 1.0 +validation_grid_nrow: 6 + +unet_from_pretrained_kwargs: + camera_embedding_type: 'e_de_da_sincos' + projection_class_embeddings_input_dim: 10 + num_views: 6 + sample_size: 32 + cd_attention_mid: true + zero_init_conv_in: false + zero_init_camera_projection: false + +num_views: 6 +camera_embedding_type: 'e_de_da_sincos' + +enable_xformers_memory_efficient_attention: true \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..2c02ea6197ba7cdba746089e8eaf4784d640faec --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,56 @@ +# get the development image from nvidia cuda 11.7 +FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 + +LABEL name="Wonder3D" \ + maintainer="Tiancheng " \ + lastupdate="2024-01-05" + +# create workspace folder and set it as working directory +RUN mkdir -p /workspace +WORKDIR /workspace + +# Set the timezone +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && \ + apt-get install -y tzdata && \ + ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ + dpkg-reconfigure --frontend noninteractive tzdata + +# update package lists and install git, wget, vim, libgl1-mesa-glx, and libglib2.0-0 +RUN apt-get update && \ + apt-get install -y git wget vim libgl1-mesa-glx libglib2.0-0 unzip + +# install conda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x Miniconda3-latest-Linux-x86_64.sh && \ + ./Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 && \ + rm Miniconda3-latest-Linux-x86_64.sh + +# update PATH environment variable +ENV PATH="/workspace/miniconda3/bin:${PATH}" + +# initialize conda +RUN conda init bash + +# create and activate conda environment +RUN conda create -n wonder3d python=3.8 && echo "source activate wonder3d" > ~/.bashrc +ENV PATH /workspace/miniconda3/envs/wonder3d/bin:$PATH + + +# clone the repository +RUN git clone https://github.com/xxlong0/Wonder3D.git && \ + cd /workspace/Wonder3D + +# change the working directory to the repository +WORKDIR /workspace/Wonder3D + +# install pytorch 1.13.1 and torchvision +RUN pip install -r docker/requirements.txt + +# install the specific version of nerfacc corresponding to torch 1.13.0 and cuda 11.7, otherwise the nerfacc will freeze during cuda setup +RUN pip install nerfacc==0.3.3 -f https://nerfacc-bucket.s3.us-west-2.amazonaws.com/whl/torch-1.13.0_cu117.html + +# install tiny cuda during docker setup will cause error, need to install it manually in the container +# RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch + + diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9c88d6aace3688d5db09ef7080118dff74b279d1 --- /dev/null +++ b/docker/README.md @@ -0,0 +1,57 @@ +# Docker setup + +This docker setup is tested on Ubunu20.04. 
+ +Make sure you are in the directory yourworkspace/Wonder3D/, then run + +`docker build --no-cache -t wonder3d/deploy:cuda11.7 -f docker/Dockerfile .` + +and then run + +`docker run --gpus all -it wonder3d/deploy:cuda11.7 bash` + + +## NVIDIA Container Toolkit setup + +You will have trouble enabling GPU support in Docker if you have not installed the **NVIDIA Container Toolkit** on your local machine. You can skip this section if you have already installed it. Follow the instructions on this page to install it: + +https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html + +or run the following commands to install it with apt: + +1. Configure the production repository: + +```bash +curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ + && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ + sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ + sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list +``` + +2. Optionally, enable the experimental packages in the repository list: + +`sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list` + +3. Update the package list from the repository: + +`sudo apt-get update` + +4. Install the NVIDIA Container Toolkit packages: + +`sudo apt-get install -y nvidia-container-toolkit` + +Remember to restart Docker: + +`sudo systemctl restart docker` + +Now you can run the following command: + +`docker run --gpus all -it wonder3d/deploy:cuda11.7 bash` + + +## Install tiny-cuda-nn + +After you start the container, run the following command to install tiny-cuda-nn. This pip installation cannot be done during the docker build, so you have to do it manually after the container has started.
+ +`pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch` + + +Now you should be good to go, good luck and have fun :) diff --git a/docker/requirements.txt b/docker/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1de9e2dfcc6ab57234975f9e8d289d3626592f05 --- /dev/null +++ b/docker/requirements.txt @@ -0,0 +1,36 @@ +--extra-index-url https://download.pytorch.org/whl/cu117 + +# nerfacc==0.3.3, nefacc needs to be installed from the specific location +# see installation part in this link: https://github.com/nerfstudio-project/nerfacc + +torch==1.13.1+cu117 +torchvision==0.14.1+cu117 +diffusers[torch]==0.19.3 +xformers==0.0.16 +transformers>=4.25.1 +bitsandbytes==0.35.4 +decord==0.6.0 +pytorch-lightning<2 +omegaconf==2.2.3 +trimesh==3.9.8 +pyhocon==0.3.57 +icecream==2.1.0 +PyMCubes==0.1.2 +accelerate +modelcards +einops +ftfy +piq +matplotlib +opencv-python +imageio +imageio-ffmpeg +scipy +pyransac3d +torch_efficient_distloss +tensorboard +rembg +segment_anything +gradio==3.50.2 +triton +rich diff --git a/example_images/14_10_29_489_Tiger_1__1.png b/example_images/14_10_29_489_Tiger_1__1.png new file mode 100644 index 0000000000000000000000000000000000000000..866827c406fbf3b05e76ba369e45056f84e945c0 Binary files /dev/null and b/example_images/14_10_29_489_Tiger_1__1.png differ diff --git a/example_images/box.png b/example_images/box.png new file mode 100644 index 0000000000000000000000000000000000000000..a254ff1f569342342d70c7ed6e4b922dadf40274 Binary files /dev/null and b/example_images/box.png differ diff --git a/example_images/bread.png b/example_images/bread.png new file mode 100644 index 0000000000000000000000000000000000000000..5814a3d1c69d9758da1a873ef4d5c6cae703c340 Binary files /dev/null and b/example_images/bread.png differ diff --git a/example_images/cat.png b/example_images/cat.png new file mode 100644 index 0000000000000000000000000000000000000000..3090618ec3b414dafba3d6843bbb951b41cb4356 Binary files /dev/null and b/example_images/cat.png differ diff --git a/example_images/cat_head.png b/example_images/cat_head.png new file mode 100644 index 0000000000000000000000000000000000000000..411fcf61ebf276dcc1de6ab971d1b70ab8fe84ab Binary files /dev/null and b/example_images/cat_head.png differ diff --git a/example_images/chili.png b/example_images/chili.png new file mode 100644 index 0000000000000000000000000000000000000000..5af023cdccf295b7d521082b0aef59ce713f9460 Binary files /dev/null and b/example_images/chili.png differ diff --git a/example_images/duola.png b/example_images/duola.png new file mode 100644 index 0000000000000000000000000000000000000000..bd93331c4e9b3b7923d0e8a311e1d3f4e5a541c4 Binary files /dev/null and b/example_images/duola.png differ diff --git a/example_images/halloween.png b/example_images/halloween.png new file mode 100644 index 0000000000000000000000000000000000000000..7502e0346e7932e53c670b772d361b026350a1ab Binary files /dev/null and b/example_images/halloween.png differ diff --git a/example_images/head.png b/example_images/head.png new file mode 100644 index 0000000000000000000000000000000000000000..373031cf8213279baf82ad79a80cdf4793081d99 Binary files /dev/null and b/example_images/head.png differ diff --git a/example_images/kettle.png b/example_images/kettle.png new file mode 100644 index 0000000000000000000000000000000000000000..de8e12d3ea2ed63a50864879360fb59a93fe4698 Binary files /dev/null and b/example_images/kettle.png differ diff --git a/example_images/kunkun.png 
b/example_images/kunkun.png new file mode 100644 index 0000000000000000000000000000000000000000..806c188eccfe9b2ad787d83ec979eee09a813127 Binary files /dev/null and b/example_images/kunkun.png differ diff --git a/example_images/milk.png b/example_images/milk.png new file mode 100644 index 0000000000000000000000000000000000000000..fc8821b09b7e05ed225bc198a22c697a585a00c6 Binary files /dev/null and b/example_images/milk.png differ diff --git a/example_images/owl.png b/example_images/owl.png new file mode 100644 index 0000000000000000000000000000000000000000..e45915d836361924ad75e581eaedf449df7b11e8 Binary files /dev/null and b/example_images/owl.png differ diff --git a/example_images/poro.png b/example_images/poro.png new file mode 100644 index 0000000000000000000000000000000000000000..98f5e9edfe23cb4de383835268efb369121db75f Binary files /dev/null and b/example_images/poro.png differ diff --git a/example_images/pumpkin.png b/example_images/pumpkin.png new file mode 100644 index 0000000000000000000000000000000000000000..cc0f090df656f7290e968f262fc607f826d2c3e8 Binary files /dev/null and b/example_images/pumpkin.png differ diff --git a/example_images/skull.png b/example_images/skull.png new file mode 100644 index 0000000000000000000000000000000000000000..c03a28f6128884af63b478d3e7162bc3ef952f21 Binary files /dev/null and b/example_images/skull.png differ diff --git a/example_images/stone.png b/example_images/stone.png new file mode 100644 index 0000000000000000000000000000000000000000..91e2b33940e029ff29c14e7fa59a9473cee1878f Binary files /dev/null and b/example_images/stone.png differ diff --git a/example_images/teapot.png b/example_images/teapot.png new file mode 100644 index 0000000000000000000000000000000000000000..1f13a6edfe67ced810b4513117279067f0360fae Binary files /dev/null and b/example_images/teapot.png differ diff --git a/example_images/tiger-head-3d-model-obj-stl.png b/example_images/tiger-head-3d-model-obj-stl.png new file mode 100644 index 0000000000000000000000000000000000000000..009efbe7bd143682803e64b5042cd88cef555c12 Binary files /dev/null and b/example_images/tiger-head-3d-model-obj-stl.png differ diff --git a/gradio_app_mv.py b/gradio_app_mv.py new file mode 100644 index 0000000000000000000000000000000000000000..930218d453d2a70561f29a12f827a5029eb466d5 --- /dev/null +++ b/gradio_app_mv.py @@ -0,0 +1,439 @@ +import os +import torch +import fire +import gradio as gr +from PIL import Image +from functools import partial + +import cv2 +import time +import numpy as np +from rembg import remove +from segment_anything import sam_model_registry, SamPredictor + +import os +import sys +import numpy +import torch +import rembg +import threading +import urllib.request +from PIL import Image +from typing import Dict, Optional, Tuple, List +from dataclasses import dataclass +import streamlit as st +import huggingface_hub +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel +from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset +from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline +from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler +from einops import rearrange +import numpy as np +import subprocess +from datetime import datetime + +def save_image(tensor): + ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + # pdb.set_trace() + im = Image.fromarray(ndarr) + 
return ndarr + + +def save_image_to_disk(tensor, fp): + ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + # pdb.set_trace() + im = Image.fromarray(ndarr) + im.save(fp) + return ndarr + + +def save_image_numpy(ndarr, fp): + im = Image.fromarray(ndarr) + im.save(fp) + + +weight_dtype = torch.float16 + +_TITLE = '''Wonder3D: Single Image to 3D using Cross-Domain Diffusion''' +_DESCRIPTION = ''' +
+Generate consistent multi-view normal maps and color images. + +
+
+The demo does not include the mesh reconstruction part; please visit our GitHub repo to get a textured mesh. +
+''' +_GPU_ID = 0 + + +if not hasattr(Image, 'Resampling'): + Image.Resampling = Image + + +def sam_init(): + sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth") + model_type = "vit_h" + + sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}") + predictor = SamPredictor(sam) + return predictor + + +def sam_segment(predictor, input_image, *bbox_coords): + bbox = np.array(bbox_coords) + image = np.asarray(input_image) + + start_time = time.time() + predictor.set_image(image) + + masks_bbox, scores_bbox, logits_bbox = predictor.predict(box=bbox, multimask_output=True) + + print(f"SAM Time: {time.time() - start_time:.3f}s") + out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8) + out_image[:, :, :3] = image + out_image_bbox = out_image.copy() + out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 + torch.cuda.empty_cache() + return Image.fromarray(out_image_bbox, mode='RGBA') + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False): + RES = 1024 + input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS) + if chk_group is not None: + segment = "Background Removal" in chk_group + rescale = "Rescale" in chk_group + if segment: + image_rem = input_image.convert('RGBA') + image_nobg = remove(image_rem, alpha_matting=True) + arr = np.asarray(image_nobg)[:, :, -1] + x_nonzero = np.nonzero(arr.sum(axis=0)) + y_nonzero = np.nonzero(arr.sum(axis=1)) + x_min = int(x_nonzero[0].min()) + y_min = int(y_nonzero[0].min()) + x_max = int(x_nonzero[0].max()) + y_max = int(y_nonzero[0].max()) + input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max) + # Rescale and recenter + if rescale: + image_arr = np.array(input_image) + in_w, in_h = image_arr.shape[:2] + out_res = min(RES, max(in_w, in_h)) + ret, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY) + x, y, w, h = cv2.boundingRect(mask) + max_size = max(w, h) + ratio = 0.75 + side_len = int(max_size / ratio) + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len // 2 + padded_image[center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w] = image_arr[y : y + h, x : x + w] + rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.LANCZOS) + + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) + input_image = Image.fromarray((rgb * 255).astype(np.uint8)) + else: + input_image = expand2square(input_image, (127, 127, 127, 0)) + return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS) + + +def load_wonder3d_pipeline(cfg): + + pipeline = MVDiffusionImagePipeline.from_pretrained( + cfg.pretrained_model_name_or_path, + torch_dtype=weight_dtype + ) + + # pipeline.to('cuda:0') + pipeline.unet.enable_xformers_memory_efficient_attention() + + + if torch.cuda.is_available(): + pipeline.to('cuda:0') + # sys.main_lock = threading.Lock() + return pipeline + + +from mvdiffusion.data.single_image_dataset 
import SingleImageDataset + + +def prepare_data(single_image, crop_size): + dataset = SingleImageDataset(root_dir='', num_views=6, img_wh=[256, 256], bg_color='white', crop_size=crop_size, single_image=single_image) + return dataset[0] + +scene = 'scene' + +def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size, chk_group=None): + import pdb + global scene + # pdb.set_trace() + + if chk_group is not None: + write_image = "Write Results" in chk_group + + batch = prepare_data(single_image, crop_size) + + pipeline.set_progress_bar_config(disable=True) + seed = int(seed) + generator = torch.Generator(device=pipeline.unet.device).manual_seed(seed) + + # repeat (2B, Nv, 3, H, W) + imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0).to(weight_dtype) + + # (2B, Nv, Nce) + camera_embeddings = torch.cat([batch['camera_embeddings']] * 2, dim=0).to(weight_dtype) + + task_embeddings = torch.cat([batch['normal_task_embeddings'], batch['color_task_embeddings']], dim=0).to(weight_dtype) + + camera_embeddings = torch.cat([camera_embeddings, task_embeddings], dim=-1).to(weight_dtype) + + # (B*Nv, 3, H, W) + imgs_in = rearrange(imgs_in, "Nv C H W -> (Nv) C H W") + # (B*Nv, Nce) + # camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce") + + out = pipeline( + imgs_in, + # camera_embeddings, + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=steps, + output_type='pt', + num_images_per_prompt=1, + **cfg.pipe_validation_kwargs, + ).images + + bsz = out.shape[0] // 2 + normals_pred = out[:bsz] + images_pred = out[bsz:] + num_views = 6 + if write_image: + VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] + cur_dir = os.path.join("./outputs", f"cropsize-{int(crop_size)}-cfg{guidance_scale:.1f}") + + scene = 'scene'+datetime.now().strftime('@%Y%m%d-%H%M%S') + scene_dir = os.path.join(cur_dir, scene) + normal_dir = os.path.join(scene_dir, "normals") + masked_colors_dir = os.path.join(scene_dir, "masked_colors") + os.makedirs(normal_dir, exist_ok=True) + os.makedirs(masked_colors_dir, exist_ok=True) + for j in range(num_views): + view = VIEWS[j] + normal = normals_pred[j] + color = images_pred[j] + + normal_filename = f"normals_000_{view}.png" + rgb_filename = f"rgb_000_{view}.png" + normal = save_image_to_disk(normal, os.path.join(normal_dir, normal_filename)) + color = save_image_to_disk(color, os.path.join(scene_dir, rgb_filename)) + + # rm_normal = remove(normal) + # rm_color = remove(color) + + # save_image_numpy(rm_normal, os.path.join(scene_dir, normal_filename)) + # save_image_numpy(rm_color, os.path.join(masked_colors_dir, rgb_filename)) + + normals_pred = [save_image(normals_pred[i]) for i in range(bsz)] + images_pred = [save_image(images_pred[i]) for i in range(bsz)] + + out = images_pred + normals_pred + return out + + +def process_3d(mode, data_dir, guidance_scale, crop_size): + dir = None + global scene + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + + subprocess.run( + f'cd instant-nsr-pl && python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../{data_dir}/cropsize-{crop_size:.1f}-cfg{guidance_scale:.1f}/ dataset.scene={scene} && cd ..', + shell=True, + ) + import glob + # import pdb + + # pdb.set_trace() + + obj_files = glob.glob(f'{cur_dir}/instant-nsr-pl/exp/{scene}/*/save/*.obj', recursive=True) + print(obj_files) + if obj_files: + dir = obj_files[0] + return dir + + +@dataclass +class TestConfig: + pretrained_model_name_or_path: str + 
pretrained_unet_path: str + revision: Optional[str] + validation_dataset: Dict + save_dir: str + seed: Optional[int] + validation_batch_size: int + dataloader_num_workers: int + + local_rank: int + + pipe_kwargs: Dict + pipe_validation_kwargs: Dict + unet_from_pretrained_kwargs: Dict + validation_guidance_scales: List[float] + validation_grid_nrow: int + camera_embedding_lr_mult: float + + num_views: int + camera_embedding_type: str + + pred_type: str # joint, or ablation + + enable_xformers_memory_efficient_attention: bool + + cond_on_normals: bool + cond_on_colors: bool + + +def run_demo(): + from utils.misc import load_config + from omegaconf import OmegaConf + + # parse YAML config to OmegaConf + cfg = load_config("./configs/mvdiffusion-joint-ortho-6views.yaml") + # print(cfg) + schema = OmegaConf.structured(TestConfig) + cfg = OmegaConf.merge(schema, cfg) + + pipeline = load_wonder3d_pipeline(cfg) + torch.set_grad_enabled(False) + pipeline.to(f'cuda:{_GPU_ID}') + + predictor = sam_init() + + custom_theme = gr.themes.Soft(primary_hue="blue").set( + button_secondary_background_fill="*neutral_100", button_secondary_background_fill_hover="*neutral_200" + ) + custom_css = '''#disp_image { + text-align: center; /* Horizontally center the content */ + }''' + + with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo: + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown('# ' + _TITLE) + gr.Markdown(_DESCRIPTION) + with gr.Row(variant='panel'): + with gr.Column(scale=1): + input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image', tool=None) + + with gr.Column(scale=1): + processed_image = gr.Image( + type='pil', + label="Processed Image", + interactive=False, + height=320, + tool=None, + image_mode='RGBA', + elem_id="disp_image", + visible=True, + ) + # with gr.Column(scale=1): + # ## add 3D Model + # obj_3d = gr.Model3D( + # # clear_color=[0.0, 0.0, 0.0, 0.0], + # label="3D Model", height=320, + # # camera_position=[0,0,2.0] + # ) + processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False, tool=None) + with gr.Row(variant='panel'): + with gr.Column(scale=1): + example_folder = os.path.join(os.path.dirname(__file__), "./example_images") + example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)] + gr.Examples( + examples=example_fns, + inputs=[input_image], + outputs=[input_image], + cache_examples=False, + label='Examples (click one of the images below to start)', + examples_per_page=30, + ) + with gr.Column(scale=1): + with gr.Accordion('Advanced options', open=True): + with gr.Row(): + with gr.Column(): + input_processing = gr.CheckboxGroup( + ['Background Removal'], + label='Input Image Preprocessing', + value=['Background Removal'], + info='untick this, if masked image with alpha channel', + ) + with gr.Column(): + output_processing = gr.CheckboxGroup( + ['Write Results'], label='write the results in ./outputs folder', value=['Write Results'] + ) + with gr.Row(): + with gr.Column(): + scale_slider = gr.Slider(1, 5, value=1, step=1, label='Classifier Free Guidance Scale') + with gr.Column(): + steps_slider = gr.Slider(15, 100, value=50, step=1, label='Number of Diffusion Inference Steps') + with gr.Row(): + with gr.Column(): + seed = gr.Number(42, label='Seed') + with gr.Column(): + crop_size = gr.Number(192, label='Crop size') + + mode = gr.Textbox('train', visible=False) + data_dir = gr.Textbox('outputs', visible=False) + # crop_size = 192 + # with gr.Row(): + # method = 
gr.Radio(choices=['instant-nsr-pl', 'NeuS'], label='Method (Default: instant-nsr-pl)', value='instant-nsr-pl') + run_btn = gr.Button('Generate Normals and Colors', variant='primary', interactive=True) + # recon_btn = gr.Button('Reconstruct 3D model', variant='primary', interactive=True) + # gr.Markdown("First click Generate button, then click Reconstruct button. Reconstruction may cost several minutes.") + + with gr.Row(): + view_1 = gr.Image(interactive=False, height=240, show_label=False) + view_2 = gr.Image(interactive=False, height=240, show_label=False) + view_3 = gr.Image(interactive=False, height=240, show_label=False) + view_4 = gr.Image(interactive=False, height=240, show_label=False) + view_5 = gr.Image(interactive=False, height=240, show_label=False) + view_6 = gr.Image(interactive=False, height=240, show_label=False) + with gr.Row(): + normal_1 = gr.Image(interactive=False, height=240, show_label=False) + normal_2 = gr.Image(interactive=False, height=240, show_label=False) + normal_3 = gr.Image(interactive=False, height=240, show_label=False) + normal_4 = gr.Image(interactive=False, height=240, show_label=False) + normal_5 = gr.Image(interactive=False, height=240, show_label=False) + normal_6 = gr.Image(interactive=False, height=240, show_label=False) + + run_btn.click( + fn=partial(preprocess, predictor), inputs=[input_image, input_processing], outputs=[processed_image_highres, processed_image], queue=True + ).success( + fn=partial(run_pipeline, pipeline, cfg), + inputs=[processed_image_highres, scale_slider, steps_slider, seed, crop_size, output_processing], + outputs=[view_1, view_2, view_3, view_4, view_5, view_6, normal_1, normal_2, normal_3, normal_4, normal_5, normal_6], + ) + # recon_btn.click( + # process_3d, inputs=[mode, data_dir, scale_slider, crop_size], outputs=[obj_3d] + # ) + + demo.queue().launch(share=True, max_threads=80) + + +if __name__ == '__main__': + fire.Fire(run_demo) diff --git a/gradio_app_recon.py b/gradio_app_recon.py new file mode 100644 index 0000000000000000000000000000000000000000..4153ae8f33530b2d82572cd26c610e6449f16555 --- /dev/null +++ b/gradio_app_recon.py @@ -0,0 +1,438 @@ +import os +import torch +import fire +import gradio as gr +from PIL import Image +from functools import partial + +import cv2 +import time +import numpy as np +from rembg import remove +from segment_anything import sam_model_registry, SamPredictor + +import os +import sys +import numpy +import torch +import rembg +import threading +import urllib.request +from PIL import Image +from typing import Dict, Optional, Tuple, List +from dataclasses import dataclass +import streamlit as st +import huggingface_hub +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel +from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset +from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline +from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler +from einops import rearrange +import numpy as np +import subprocess +from datetime import datetime + +def save_image(tensor): + ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + # pdb.set_trace() + im = Image.fromarray(ndarr) + return ndarr + + +def save_image_to_disk(tensor, fp): + ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + # pdb.set_trace() + im = Image.fromarray(ndarr) + 
im.save(fp) + return ndarr + + +def save_image_numpy(ndarr, fp): + im = Image.fromarray(ndarr) + im.save(fp) + + +weight_dtype = torch.float16 + +_TITLE = '''Wonder3D: Single Image to 3D using Cross-Domain Diffusion''' +_DESCRIPTION = ''' +
+Generate consistent multi-view normal maps and color images. + +
+
+The demo does not include the mesh reconstruction part; please visit our GitHub repo to get a textured mesh. +
+''' +_GPU_ID = 0 + + +if not hasattr(Image, 'Resampling'): + Image.Resampling = Image + + +def sam_init(): + sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth") + model_type = "vit_h" + + sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}") + predictor = SamPredictor(sam) + return predictor + + +def sam_segment(predictor, input_image, *bbox_coords): + bbox = np.array(bbox_coords) + image = np.asarray(input_image) + + start_time = time.time() + predictor.set_image(image) + + masks_bbox, scores_bbox, logits_bbox = predictor.predict(box=bbox, multimask_output=True) + + print(f"SAM Time: {time.time() - start_time:.3f}s") + out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8) + out_image[:, :, :3] = image + out_image_bbox = out_image.copy() + out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 + torch.cuda.empty_cache() + return Image.fromarray(out_image_bbox, mode='RGBA') + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False): + RES = 1024 + input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS) + if chk_group is not None: + segment = "Background Removal" in chk_group + rescale = "Rescale" in chk_group + if segment: + image_rem = input_image.convert('RGBA') + image_nobg = remove(image_rem, alpha_matting=True) + arr = np.asarray(image_nobg)[:, :, -1] + x_nonzero = np.nonzero(arr.sum(axis=0)) + y_nonzero = np.nonzero(arr.sum(axis=1)) + x_min = int(x_nonzero[0].min()) + y_min = int(y_nonzero[0].min()) + x_max = int(x_nonzero[0].max()) + y_max = int(y_nonzero[0].max()) + input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max) + # Rescale and recenter + if rescale: + image_arr = np.array(input_image) + in_w, in_h = image_arr.shape[:2] + out_res = min(RES, max(in_w, in_h)) + ret, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY) + x, y, w, h = cv2.boundingRect(mask) + max_size = max(w, h) + ratio = 0.75 + side_len = int(max_size / ratio) + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len // 2 + padded_image[center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w] = image_arr[y : y + h, x : x + w] + rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.LANCZOS) + + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) + input_image = Image.fromarray((rgb * 255).astype(np.uint8)) + else: + input_image = expand2square(input_image, (127, 127, 127, 0)) + return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS) + + +def load_wonder3d_pipeline(cfg): + + pipeline = MVDiffusionImagePipeline.from_pretrained( + cfg.pretrained_model_name_or_path, + torch_dtype=weight_dtype + ) + + # pipeline.to('cuda:0') + pipeline.unet.enable_xformers_memory_efficient_attention() + + + if torch.cuda.is_available(): + pipeline.to('cuda:0') + # sys.main_lock = threading.Lock() + return pipeline + + +from mvdiffusion.data.single_image_dataset 
import SingleImageDataset + + +def prepare_data(single_image, crop_size): + dataset = SingleImageDataset(root_dir='', num_views=6, img_wh=[256, 256], bg_color='white', crop_size=crop_size, single_image=single_image) + return dataset[0] + +scene = 'scene' + +def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size, chk_group=None): + import pdb + global scene + # pdb.set_trace() + + if chk_group is not None: + write_image = "Write Results" in chk_group + + batch = prepare_data(single_image, crop_size) + + pipeline.set_progress_bar_config(disable=True) + seed = int(seed) + generator = torch.Generator(device=pipeline.unet.device).manual_seed(seed) + + # repeat (2B, Nv, 3, H, W) + imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0).to(weight_dtype) + + # (2B, Nv, Nce) + camera_embeddings = torch.cat([batch['camera_embeddings']] * 2, dim=0).to(weight_dtype) + + task_embeddings = torch.cat([batch['normal_task_embeddings'], batch['color_task_embeddings']], dim=0).to(weight_dtype) + + camera_embeddings = torch.cat([camera_embeddings, task_embeddings], dim=-1).to(weight_dtype) + + # (B*Nv, 3, H, W) + imgs_in = rearrange(imgs_in, "Nv C H W -> (Nv) C H W") + # (B*Nv, Nce) + # camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce") + + out = pipeline( + imgs_in, + camera_embeddings, + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=steps, + output_type='pt', + num_images_per_prompt=1, + **cfg.pipe_validation_kwargs, + ).images + + bsz = out.shape[0] // 2 + normals_pred = out[:bsz] + images_pred = out[bsz:] + num_views = 6 + if write_image: + VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] + cur_dir = os.path.join("./outputs", f"cropsize-{int(crop_size)}-cfg{guidance_scale:.1f}") + + scene = 'scene'+datetime.now().strftime('@%Y%m%d-%H%M%S') + scene_dir = os.path.join(cur_dir, scene) + normal_dir = os.path.join(scene_dir, "normals") + masked_colors_dir = os.path.join(scene_dir, "masked_colors") + os.makedirs(normal_dir, exist_ok=True) + os.makedirs(masked_colors_dir, exist_ok=True) + for j in range(num_views): + view = VIEWS[j] + normal = normals_pred[j] + color = images_pred[j] + + normal_filename = f"normals_000_{view}.png" + rgb_filename = f"rgb_000_{view}.png" + normal = save_image_to_disk(normal, os.path.join(normal_dir, normal_filename)) + color = save_image_to_disk(color, os.path.join(scene_dir, rgb_filename)) + + rm_normal = remove(normal) + rm_color = remove(color) + + save_image_numpy(rm_normal, os.path.join(scene_dir, normal_filename)) + save_image_numpy(rm_color, os.path.join(masked_colors_dir, rgb_filename)) + + normals_pred = [save_image(normals_pred[i]) for i in range(bsz)] + images_pred = [save_image(images_pred[i]) for i in range(bsz)] + + out = images_pred + normals_pred + return out + + +def process_3d(mode, data_dir, guidance_scale, crop_size): + dir = None + global scene + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + + subprocess.run( + f'cd instant-nsr-pl && python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../{data_dir}/cropsize-{int(crop_size)}-cfg{guidance_scale:.1f}/ dataset.scene={scene} && cd ..', + shell=True, + ) + import glob + # import pdb + + # pdb.set_trace() + + obj_files = glob.glob(f'{cur_dir}/instant-nsr-pl/exp/{scene}/*/save/*.obj', recursive=True) + print(obj_files) + if obj_files: + dir = obj_files[0] + return dir + + +@dataclass +class TestConfig: + pretrained_model_name_or_path: str + pretrained_unet_path: 
str + revision: Optional[str] + validation_dataset: Dict + save_dir: str + seed: Optional[int] + validation_batch_size: int + dataloader_num_workers: int + + local_rank: int + + pipe_kwargs: Dict + pipe_validation_kwargs: Dict + unet_from_pretrained_kwargs: Dict + validation_guidance_scales: List[float] + validation_grid_nrow: int + camera_embedding_lr_mult: float + + num_views: int + camera_embedding_type: str + + pred_type: str # joint, or ablation + + enable_xformers_memory_efficient_attention: bool + + cond_on_normals: bool + cond_on_colors: bool + + +def run_demo(): + from utils.misc import load_config + from omegaconf import OmegaConf + + # parse YAML config to OmegaConf + cfg = load_config("./configs/mvdiffusion-joint-ortho-6views.yaml") + # print(cfg) + schema = OmegaConf.structured(TestConfig) + cfg = OmegaConf.merge(schema, cfg) + + pipeline = load_wonder3d_pipeline(cfg) + torch.set_grad_enabled(False) + pipeline.to(f'cuda:{_GPU_ID}') + + predictor = sam_init() + + custom_theme = gr.themes.Soft(primary_hue="blue").set( + button_secondary_background_fill="*neutral_100", button_secondary_background_fill_hover="*neutral_200" + ) + custom_css = '''#disp_image { + text-align: center; /* Horizontally center the content */ + }''' + + with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo: + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown('# ' + _TITLE) + gr.Markdown(_DESCRIPTION) + with gr.Row(variant='panel'): + with gr.Column(scale=1): + input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image', tool=None) + + with gr.Column(scale=1): + processed_image = gr.Image( + type='pil', + label="Processed Image", + interactive=False, + height=320, + tool=None, + image_mode='RGBA', + elem_id="disp_image", + visible=True, + ) + with gr.Column(scale=1): + ## add 3D Model + obj_3d = gr.Model3D( + # clear_color=[0.0, 0.0, 0.0, 0.0], + label="3D Model", height=320, + # camera_position=[0,0,2.0] + ) + processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False, tool=None) + with gr.Row(variant='panel'): + with gr.Column(scale=1): + example_folder = os.path.join(os.path.dirname(__file__), "./example_images") + example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)] + gr.Examples( + examples=example_fns, + inputs=[input_image], + outputs=[input_image], + cache_examples=False, + label='Examples (click one of the images below to start)', + examples_per_page=30, + ) + with gr.Column(scale=1): + with gr.Accordion('Advanced options', open=True): + with gr.Row(): + with gr.Column(): + input_processing = gr.CheckboxGroup( + ['Background Removal'], + label='Input Image Preprocessing', + value=['Background Removal'], + info='untick this, if masked image with alpha channel', + ) + with gr.Column(): + output_processing = gr.CheckboxGroup( + ['Write Results'], label='write the results in ./outputs folder', value=['Write Results'] + ) + with gr.Row(): + with gr.Column(): + scale_slider = gr.Slider(1, 5, value=1, step=1, label='Classifier Free Guidance Scale') + with gr.Column(): + steps_slider = gr.Slider(15, 100, value=50, step=1, label='Number of Diffusion Inference Steps') + with gr.Row(): + with gr.Column(): + seed = gr.Number(42, label='Seed') + with gr.Column(): + crop_size = gr.Number(192, label='Crop size') + + mode = gr.Textbox('train', visible=False) + data_dir = gr.Textbox('outputs', visible=False) + # crop_size = 192 + # with gr.Row(): + # method = gr.Radio(choices=['instant-nsr-pl', 
'NeuS'], label='Method (Default: instant-nsr-pl)', value='instant-nsr-pl') + # run_btn = gr.Button('Generate Normals and Colors', variant='primary', interactive=True) + run_btn = gr.Button('Reconstruct 3D model', variant='primary', interactive=True) + gr.Markdown(" Reconstruction may cost several minutes. Check results in instant-nsr-pl/exp/scene@{current-time}/ ") + + with gr.Row(): + view_1 = gr.Image(interactive=False, height=240, show_label=False) + view_2 = gr.Image(interactive=False, height=240, show_label=False) + view_3 = gr.Image(interactive=False, height=240, show_label=False) + view_4 = gr.Image(interactive=False, height=240, show_label=False) + view_5 = gr.Image(interactive=False, height=240, show_label=False) + view_6 = gr.Image(interactive=False, height=240, show_label=False) + with gr.Row(): + normal_1 = gr.Image(interactive=False, height=240, show_label=False) + normal_2 = gr.Image(interactive=False, height=240, show_label=False) + normal_3 = gr.Image(interactive=False, height=240, show_label=False) + normal_4 = gr.Image(interactive=False, height=240, show_label=False) + normal_5 = gr.Image(interactive=False, height=240, show_label=False) + normal_6 = gr.Image(interactive=False, height=240, show_label=False) + + run_btn.click( + fn=partial(preprocess, predictor), inputs=[input_image, input_processing], outputs=[processed_image_highres, processed_image], queue=True + ).success( + fn=partial(run_pipeline, pipeline, cfg), + inputs=[processed_image_highres, scale_slider, steps_slider, seed, crop_size, output_processing], + outputs=[view_1, view_2, view_3, view_4, view_5, view_6, normal_1, normal_2, normal_3, normal_4, normal_5, normal_6], + ).success( + process_3d, inputs=[mode, data_dir, scale_slider, crop_size], outputs=[obj_3d] + ) + + demo.queue().launch(share=True, max_threads=80) + + +if __name__ == '__main__': + fire.Fire(run_demo) diff --git a/instant-nsr-pl/README.md b/instant-nsr-pl/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7964ddfe4641028c4f92566532ea53f0ec20f30d --- /dev/null +++ b/instant-nsr-pl/README.md @@ -0,0 +1,122 @@ +# Instant Neural Surface Reconstruction + +This repository contains a concise and extensible implementation of NeRF and NeuS for neural surface reconstruction based on Instant-NGP and the Pytorch-Lightning framework. 
**Training on a NeRF-Synthetic scene takes ~5min for NeRF and ~10min for NeuS on a single RTX3090.** + +||NeRF in 5min|NeuS in 10 min| +|---|---|---| +|Rendering|![rendering-nerf](https://user-images.githubusercontent.com/19284678/199078178-b719676b-7e60-47f1-813b-c0b533f5480d.png)|![rendering-neus](https://user-images.githubusercontent.com/19284678/199078300-ebcf249d-b05e-431f-b035-da354705d8db.png)| +|Mesh|![mesh-nerf](https://user-images.githubusercontent.com/19284678/199078661-b5cd569a-c22b-4220-9c11-d5fd13a52fb8.png)|![mesh-neus](https://user-images.githubusercontent.com/19284678/199078481-164e36a6-6d55-45cc-aaf3-795a114e4a38.png)| + + +## Features +**This repository aims to provide a highly efficient while customizable boilerplate for research projects based on NeRF or NeuS.** + +- acceleration techniques from [Instant-NGP](https://github.com/NVlabs/instant-ngp): multiresolution hash encoding and fully fused networks by [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn), occupancy grid pruning and rendering by [nerfacc](https://github.com/KAIR-BAIR/nerfacc) +- out-of-the-box multi-GPU and mixed precision training by [PyTorch-Lightning](https://github.com/Lightning-AI/lightning) +- hierarchical project layout that is designed to be easily customized and extended, flexible experiment configuration by [OmegaConf](https://github.com/omry/omegaconf) + +**Please subscribe to [#26](https://github.com/bennyguo/instant-nsr-pl/issues/26) for our latest findings on quality improvements!** + +## News + +🔥🔥🔥 Check out my new project on 3D content generation: https://github.com/threestudio-project/threestudio 🔥🔥🔥 + +- 06/03/2023: Add an implementation of [Neuralangelo](https://research.nvidia.com/labs/dir/neuralangelo/). See [here](https://github.com/bennyguo/instant-nsr-pl#training-on-DTU) for details. +- 03/31/2023: NeuS model now supports background modeling. You could try on the DTU dataset provided by [NeuS](https://drive.google.com/drive/folders/1Nlzejs4mfPuJYORLbDEUDWlc9IZIbU0C?usp=sharing) or [IDR](https://www.dropbox.com/sh/5tam07ai8ch90pf/AADniBT3dmAexvm_J1oL__uoa) following [the instruction here](https://github.com/bennyguo/instant-nsr-pl#training-on-DTU). +- 02/11/2023: NeRF model now supports unbounded 360 scenes with learned background. You could try on [MipNeRF 360 data](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip) following [the COLMAP configuration](https://github.com/bennyguo/instant-nsr-pl#training-on-custom-colmap-data). + +## Requirements +**Note:** +- To utilize multiresolution hash encoding or fully fused networks provided by tiny-cuda-nn, you should have least an RTX 2080Ti, see [https://github.com/NVlabs/tiny-cuda-nn#requirements](https://github.com/NVlabs/tiny-cuda-nn#requirements) for more details. +- Multi-GPU training is currently not supported on Windows (see [#4](https://github.com/bennyguo/instant-nsr-pl/issues/4)). +### Environments +- Install PyTorch>=1.10 [here](https://pytorch.org/get-started/locally/) based the package management tool you used and your cuda version (older PyTorch versions may work but have not been tested) +- Install tiny-cuda-nn PyTorch extension: `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch` +- `pip install -r requirements.txt` + + +## Run +### Training on NeRF-Synthetic +Download the NeRF-Synthetic data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and put it under `load/`. The file structure should be like `load/nerf_synthetic/lego`. 
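As an illustration only (the archive name and download location are assumptions; adjust them to wherever you saved the data), putting the dataset in place could look like this:

```bash
# Illustrative paths: adjust to wherever you downloaded the NeRF-Synthetic archive.
mkdir -p load
unzip ~/Downloads/nerf_synthetic.zip -d load/
# The final layout should match load/nerf_synthetic/<scene>; each scene folder
# contains transforms_train/val/test.json plus the per-split image folders.
ls load/nerf_synthetic/lego
```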
+ +Run the launch script with `--train`, specifying the config file, the GPU(s) to be used (GPU 0 will be used by default), and the scene name: +```bash +# train NeRF +python launch.py --config configs/nerf-blender.yaml --gpu 0 --train dataset.scene=lego tag=example + +# train NeuS with mask +python launch.py --config configs/neus-blender.yaml --gpu 0 --train dataset.scene=lego tag=example +# train NeuS without mask +python launch.py --config configs/neus-blender.yaml --gpu 0 --train dataset.scene=lego tag=example system.loss.lambda_mask=0.0 +``` +The code snapshots, checkpoints and experiment outputs are saved to `exp/[name]/[tag]@[timestamp]`, and tensorboard logs can be found at `runs/[name]/[tag]@[timestamp]`. You can change any configuration in the YAML file by specifying arguments without `--`, for example: +```bash +python launch.py --config configs/nerf-blender.yaml --gpu 0 --train dataset.scene=lego tag=iter50k seed=0 trainer.max_steps=50000 +``` +### Training on DTU +Download preprocessed DTU data provided by [NeuS](https://drive.google.com/drive/folders/1Nlzejs4mfPuJYORLbDEUDWlc9IZIbU0C?usp=sharing) or [IDR](https://www.dropbox.com/sh/5tam07ai8ch90pf/AADniBT3dmAexvm_J1oL__uoa). In the provided config files we assume using NeuS DTU data. If you are using IDR DTU data, please set `dataset.cameras_file=cameras.npz`. You may also need to adjust `dataset.root_dir` to point to your downloaded data location. +```bash +# train NeuS on DTU without mask +python launch.py --config configs/neus-dtu.yaml --gpu 0 --train +# train NeuS on DTU with mask +python launch.py --config configs/neus-dtu-wmask.yaml --gpu 0 --train +# train NeuS on DTU with mask using tricks from Neuralangelo (experimental) +python launch.py --config configs/neuralangelo-dtu-wmask.yaml --gpu 0 --train +``` +Notes: +- PSNR in the testing stage is meaningless, as we simply compare to pure white images in testing. +- The results of Neuralangelo can't reach those in the original paper. Some potential improvements: more iterations; larger `system.geometry.xyz_encoding_config.update_steps`; larger `system.geometry.xyz_encoding_config.n_features_per_level`; larger `system.geometry.xyz_encoding_config.log2_hashmap_size`; adopting curvature loss. + +### Training on Custom COLMAP Data +To get COLMAP data from custom images, you should have COLMAP installed (see [here](https://colmap.github.io/install.html) for installation instructions). Then put your images in the `images/` folder, and run `scripts/imgs2poses.py` specifying the path containing the `images/` folder. For example: +```bash +python scripts/imgs2poses.py ./load/bmvs_dog # images are in ./load/bmvs_dog/images +``` +Existing data following this file structure also works as long as images are store in `images/` and there is a `sparse/` folder for the COLMAP output, for example [the data provided by MipNeRF 360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip). An optional `masks/` folder could be provided for object mask supervision. To train on COLMAP data, please refer to the example config files `config/*-colmap.yaml`. Some notes: +- Adapt the `root_dir` and `img_wh` (or `img_downscale`) option in the config file to your data; +- The scene is normalized so that cameras have a minimum distance `1.0` to the center of the scene. Setting `model.radius=1.0` works in most cases. If not, try setting a smaller radius that wraps tightly to your foreground object. 
+- There are three choices to determine the scene center: `dataset.center_est_method=camera` uses the center of all camera positions as the scene center; `dataset.center_est_method=lookat` assumes the cameras are looking at the same point and calculates an approximate look-at point as the scene center; `dataset.center_est_method=point` uses the center of all points (reconstructed by COLMAP) that are bounded by cameras as the scene center. Please choose an appropriate method according to your capture. +- PSNR in the testing stage is meaningless, as we simply compare to pure white images in testing. + +### Testing +The training procedure are by default followed by testing, which computes metrics on test data, generates animations and exports the geometry as triangular meshes. If you want to do testing alone, just resume the pretrained model and replace `--train` with `--test`, for example: +```bash +python launch.py --config path/to/your/exp/config/parsed.yaml --resume path/to/your/exp/ckpt/epoch=0-step=20000.ckpt --gpu 0 --test +``` + + +## Benchmarks +All experiments are conducted on a single NVIDIA RTX3090. + +|PSNR|Chair|Drums|Ficus|Hotdog|Lego|Materials|Mic|Ship|Avg.| +|---|---|---|---|---|---|---|---|---|---| +|NeRF Paper|33.00|25.01|30.13|36.18|32.54|29.62|32.91|28.65|31.01| +|NeRF Ours (20k)|34.80|26.04|33.89|37.42|35.33|29.46|35.22|31.17|32.92| +|NeuS Ours (20k, with masks)|34.04|25.26|32.47|35.94|33.78|27.67|33.43|29.50|31.51| + +|Training Time (mm:ss)|Chair|Drums|Ficus|Hotdog|Lego|Materials|Mic|Ship|Avg.| +|---|---|---|---|---|---|---|---|---|---| +|NeRF Ours (20k)|04:34|04:35|04:18|04:46|04:39|04:35|04:26|05:41|04:42| +|NeuS Ours (20k, with masks)|11:25|10:34|09:51|12:11|11:37|11:46|09:59|16:25|11:44| + + +## TODO +- [✅] Support more dataset formats, like COLMAP outputs and DTU +- [✅] Support simple background model +- [ ] Support GUI training and interaction +- [ ] More illustrations about the framework + +## Related Projects +- [ngp_pl](https://github.com/kwea123/ngp_pl): Great Instant-NGP implementation in PyTorch-Lightning! Background model and GUI supported. +- [Instant-NSR](https://github.com/zhaofuq/Instant-NSR): NeuS implementation using multiresolution hash encoding. 
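As a concrete sketch of the testing workflow described above (the experiment name, tag, and timestamp are hypothetical; substitute the directory created by your own training run), resuming a finished run and locating the exported mesh could look like this:

```bash
# Hypothetical experiment directory of the form exp/[name]/[tag]@[timestamp]
EXP=exp/neus-blender-lego/example@20230101-120000

# Re-run evaluation and mesh export from the saved checkpoint (same flags as the command above)
python launch.py --config $EXP/config/parsed.yaml \
    --resume $EXP/ckpt/epoch=0-step=20000.ckpt \
    --gpu 0 --test

# Exported meshes are written under the experiment's save/ folder
ls $EXP/save/*.obj
```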
+ +## Citation +If you find this codebase useful, please consider citing: +``` +@misc{instant-nsr-pl, + Author = {Yuan-Chen Guo}, + Year = {2022}, + Note = {https://github.com/bennyguo/instant-nsr-pl}, + Title = {Instant Neural Surface Reconstruction} +} +``` diff --git a/instant-nsr-pl/configs/neuralangelo-ortho-wmask.yaml b/instant-nsr-pl/configs/neuralangelo-ortho-wmask.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4411ac8173d87e17717ec7381165b03c5b464d7f --- /dev/null +++ b/instant-nsr-pl/configs/neuralangelo-ortho-wmask.yaml @@ -0,0 +1,145 @@ +name: ${basename:${dataset.scene}} +tag: "" +seed: 42 + +dataset: + name: ortho + root_dir: /home/xiaoxiao/Workplace/wonder3Dplus/outputs/joint-twice/aigc/cropsize-224-cfg1.0 + cam_pose_dir: null + scene: scene_name + imSize: [1024, 1024] # should use larger res, otherwise the exported mesh has wrong colors + camera_type: ortho + apply_mask: true + camera_params: null + view_weights: [1.0, 0.8, 0.2, 1.0, 0.4, 0.7] #['front', 'front_right', 'right', 'back', 'left', 'front_left'] + # view_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + +model: + name: neus + radius: 1.0 + num_samples_per_ray: 1024 + train_num_rays: 256 + max_train_num_rays: 8192 + grid_prune: true + grid_prune_occ_thre: 0.001 + dynamic_ray_sampling: true + batch_image_sampling: true + randomized: true + ray_chunk: 2048 + cos_anneal_end: 20000 + learned_background: false + background_color: black + variance: + init_val: 0.3 + modulate: false + geometry: + name: volume-sdf + radius: ${model.radius} + feature_dim: 13 + grad_type: finite_difference + finite_difference_eps: progressive + isosurface: + method: mc + resolution: 192 + chunk: 2097152 + threshold: 0. + xyz_encoding_config: + otype: ProgressiveBandHashGrid + n_levels: 10 # 12 modify + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 32 + per_level_scale: 1.3195079107728942 + include_xyz: true + start_level: 4 + start_step: 0 + update_steps: 1000 + mlp_network_config: + otype: VanillaMLP + activation: ReLU + output_activation: none + n_neurons: 64 + n_hidden_layers: 1 + sphere_init: true + sphere_init_radius: 0.5 + weight_norm: true + texture: + name: volume-radiance + input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input + dir_encoding_config: + otype: SphericalHarmonics + degree: 4 + mlp_network_config: + otype: VanillaMLP + activation: ReLU + output_activation: none + n_neurons: 64 + n_hidden_layers: 2 + color_activation: sigmoid + +system: + name: ortho-neus-system + loss: + lambda_rgb_mse: 0.5 + lambda_rgb_l1: 0. + lambda_mask: 1.0 + lambda_eikonal: 0.2 # cannot be too large, will cause holes to thin objects + lambda_normal: 1.0 # cannot be too large + lambda_3d_normal_smooth: 1.0 + # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup + lambda_curvature: 0. 
+ lambda_sparsity: 0.5 + lambda_distortion: 0.0 + lambda_distortion_bg: 0.0 + lambda_opaque: 0.0 + sparsity_scale: 100.0 + geo_aware: true + rgb_p_ratio: 0.8 + normal_p_ratio: 0.8 + mask_p_ratio: 0.9 + optimizer: + name: AdamW + args: + lr: 0.01 + betas: [0.9, 0.99] + eps: 1.e-15 + params: + geometry: + lr: 0.001 + texture: + lr: 0.01 + variance: + lr: 0.001 + constant_steps: 500 + scheduler: + name: SequentialLR + interval: step + milestones: + - ${system.constant_steps} + schedulers: + - name: ConstantLR + args: + factor: 1.0 + total_iters: ${system.constant_steps} + - name: ExponentialLR + args: + gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}} + +checkpoint: + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} + +export: + chunk_size: 2097152 + export_vertex_color: True + ortho_scale: 1.35 #modify + +trainer: + max_steps: 3000 + log_every_n_steps: 100 + num_sanity_val_steps: 0 + val_check_interval: 4000 + limit_train_batches: 1.0 + limit_val_batches: 2 + enable_progress_bar: true + precision: 16 diff --git a/instant-nsr-pl/datasets/__init__.py b/instant-nsr-pl/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7904ca9cbb8465d618ca6160a8562be0a594a0ab --- /dev/null +++ b/instant-nsr-pl/datasets/__init__.py @@ -0,0 +1,16 @@ +datasets = {} + + +def register(name): + def decorator(cls): + datasets[name] = cls + return cls + return decorator + + +def make(name, config): + dataset = datasets[name](config) + return dataset + + +from . import blender, colmap, dtu, ortho diff --git a/instant-nsr-pl/datasets/blender.py b/instant-nsr-pl/datasets/blender.py new file mode 100644 index 0000000000000000000000000000000000000000..3affa110ddc2f17db7c8678aad0980e799c83cef --- /dev/null +++ b/instant-nsr-pl/datasets/blender.py @@ -0,0 +1,135 @@ +import os +import json +import math +import numpy as np +from PIL import Image + +import torch +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF + +import pytorch_lightning as pl + +import datasets +from models.ray_utils import get_ray_directions +from utils.misc import get_rank + + +class BlenderDatasetBase(): + def setup(self, config, split): + self.config = config + self.split = split + self.rank = get_rank() + + self.has_mask = True + self.apply_mask = True + + with open(os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), 'r') as f: + meta = json.load(f) + + if 'w' in meta and 'h' in meta: + W, H = int(meta['w']), int(meta['h']) + else: + W, H = 800, 800 + + if 'img_wh' in self.config: + w, h = self.config.img_wh + assert round(W / w * h) == H + elif 'img_downscale' in self.config: + w, h = W // self.config.img_downscale, H // self.config.img_downscale + else: + raise KeyError("Either img_wh or img_downscale should be specified.") + + self.w, self.h = w, h + self.img_wh = (self.w, self.h) + + self.near, self.far = self.config.near_plane, self.config.far_plane + + self.focal = 0.5 * w / math.tan(0.5 * meta['camera_angle_x']) # scaled focal length + + # ray directions for all pixels, same for all images (same H, W, focal) + self.directions = \ + get_ray_directions(self.w, self.h, self.focal, self.focal, self.w//2, self.h//2).to(self.rank) # (h, w, 3) + + self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] + + for i, frame in enumerate(meta['frames']): + c2w = torch.from_numpy(np.array(frame['transform_matrix'])[:3, :4]) + self.all_c2w.append(c2w) + + img_path = 
os.path.join(self.config.root_dir, f"{frame['file_path']}.png") + img = Image.open(img_path) + img = img.resize(self.img_wh, Image.BICUBIC) + img = TF.to_tensor(img).permute(1, 2, 0) # (4, h, w) => (h, w, 4) + + self.all_fg_masks.append(img[..., -1]) # (h, w) + self.all_images.append(img[...,:3]) + + self.all_c2w, self.all_images, self.all_fg_masks = \ + torch.stack(self.all_c2w, dim=0).float().to(self.rank), \ + torch.stack(self.all_images, dim=0).float().to(self.rank), \ + torch.stack(self.all_fg_masks, dim=0).float().to(self.rank) + + +class BlenderDataset(Dataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return { + 'index': index + } + + +class BlenderIterableDataset(IterableDataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register('blender') +class BlenderDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, 'fit']: + self.train_dataset = BlenderIterableDataset(self.config, self.config.train_split) + if stage in [None, 'fit', 'validate']: + self.val_dataset = BlenderDataset(self.config, self.config.val_split) + if stage in [None, 'test']: + self.test_dataset = BlenderDataset(self.config, self.config.test_split) + if stage in [None, 'predict']: + self.predict_dataset = BlenderDataset(self.config, self.config.train_split) + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/instant-nsr-pl/datasets/colmap.py b/instant-nsr-pl/datasets/colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b389ebb09b8169019046ca8afbcce872e5d30a --- /dev/null +++ b/instant-nsr-pl/datasets/colmap.py @@ -0,0 +1,332 @@ +import os +import math +import numpy as np +from PIL import Image + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF + +import pytorch_lightning as pl + +import datasets +from datasets.colmap_utils import \ + read_cameras_binary, read_images_binary, read_points3d_binary +from models.ray_utils import get_ray_directions +from utils.misc import get_rank + + +def get_center(pts): + center = pts.mean(0) + dis = (pts - center[None,:]).norm(p=2, dim=-1) + mean, std = dis.mean(), dis.std() + q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75) + valid = (dis > mean - 1.5 * std) & (dis < mean + 1.5 * std) & (dis > mean - (q75 - q25) * 1.5) & (dis < mean + (q75 - q25) * 1.5) + center = pts[valid].mean(0) + return center + +def normalize_poses(poses, pts, up_est_method, center_est_method): + if center_est_method == 'camera': + # estimation scene center as the average of all camera positions + center = poses[...,3].mean(0) + elif center_est_method == 
'lookat': + # estimation scene center as the average of the intersection of selected pairs of camera rays + cams_ori = poses[...,3] + cams_dir = poses[:,:3,:3] @ torch.as_tensor([0.,0.,-1.]) + cams_dir = F.normalize(cams_dir, dim=-1) + A = torch.stack([cams_dir, -cams_dir.roll(1,0)], dim=-1) + b = -cams_ori + cams_ori.roll(1,0) + t = torch.linalg.lstsq(A, b).solution + center = (torch.stack([cams_dir, cams_dir.roll(1,0)], dim=-1) * t[:,None,:] + torch.stack([cams_ori, cams_ori.roll(1,0)], dim=-1)).mean((0,2)) + elif center_est_method == 'point': + # first estimation scene center as the average of all camera positions + # later we'll use the center of all points bounded by the cameras as the final scene center + center = poses[...,3].mean(0) + else: + raise NotImplementedError(f'Unknown center estimation method: {center_est_method}') + + if up_est_method == 'ground': + # estimate up direction as the normal of the estimated ground plane + # use RANSAC to estimate the ground plane in the point cloud + import pyransac3d as pyrsc + ground = pyrsc.Plane() + plane_eq, inliers = ground.fit(pts.numpy(), thresh=0.01) # TODO: determine thresh based on scene scale + plane_eq = torch.as_tensor(plane_eq) # A, B, C, D in Ax + By + Cz + D = 0 + z = F.normalize(plane_eq[:3], dim=-1) # plane normal as up direction + signed_distance = (torch.cat([pts, torch.ones_like(pts[...,0:1])], dim=-1) * plane_eq).sum(-1) + if signed_distance.mean() < 0: + z = -z # flip the direction if points lie under the plane + elif up_est_method == 'camera': + # estimate up direction as the average of all camera up directions + z = F.normalize((poses[...,3] - center).mean(0), dim=0) + else: + raise NotImplementedError(f'Unknown up estimation method: {up_est_method}') + + # new axis + y_ = torch.as_tensor([z[1], -z[0], 0.]) + x = F.normalize(y_.cross(z), dim=0) + y = z.cross(x) + + if center_est_method == 'point': + # rotation + Rc = torch.stack([x, y, z], dim=1) + R = Rc.T + poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1) + inv_trans = torch.cat([torch.cat([R, torch.as_tensor([[0.,0.,0.]]).T], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0) + poses_norm = (inv_trans @ poses_homo)[:,:3] + pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0] + + # translation and scaling + poses_min, poses_max = poses_norm[...,3].min(0)[0], poses_norm[...,3].max(0)[0] + pts_fg = pts[(poses_min[0] < pts[:,0]) & (pts[:,0] < poses_max[0]) & (poses_min[1] < pts[:,1]) & (pts[:,1] < poses_max[1])] + center = get_center(pts_fg) + tc = center.reshape(3, 1) + t = -tc + poses_homo = torch.cat([poses_norm, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses_norm.shape[0], -1, -1)], dim=1) + inv_trans = torch.cat([torch.cat([torch.eye(3), t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0) + poses_norm = (inv_trans @ poses_homo)[:,:3] + scale = poses_norm[...,3].norm(p=2, dim=-1).min() + poses_norm[...,3] /= scale + pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0] + pts = pts / scale + else: + # rotation and translation + Rc = torch.stack([x, y, z], dim=1) + tc = center.reshape(3, 1) + R, t = Rc.T, -Rc.T @ tc + poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1) + inv_trans = torch.cat([torch.cat([R, t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0) + poses_norm = (inv_trans @ poses_homo)[:,:3] # (N_images, 4, 4) + + # scaling + scale = poses_norm[...,3].norm(p=2, 
dim=-1).min() + poses_norm[...,3] /= scale + + # apply the transformation to the point cloud + pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0] + pts = pts / scale + + return poses_norm, pts + +def create_spheric_poses(cameras, n_steps=120): + center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device) + mean_d = (cameras - center[None,:]).norm(p=2, dim=-1).mean() + mean_h = cameras[:,2].mean() + r = (mean_d**2 - mean_h**2).sqrt() + up = torch.as_tensor([0., 0., 1.], dtype=center.dtype, device=center.device) + + all_c2w = [] + for theta in torch.linspace(0, 2 * math.pi, n_steps): + cam_pos = torch.stack([r * theta.cos(), r * theta.sin(), mean_h]) + l = F.normalize(center - cam_pos, p=2, dim=0) + s = F.normalize(l.cross(up), p=2, dim=0) + u = F.normalize(s.cross(l), p=2, dim=0) + c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1) + all_c2w.append(c2w) + + all_c2w = torch.stack(all_c2w, dim=0) + + return all_c2w + +class ColmapDatasetBase(): + # the data only has to be processed once + initialized = False + properties = {} + + def setup(self, config, split): + self.config = config + self.split = split + self.rank = get_rank() + + if not ColmapDatasetBase.initialized: + camdata = read_cameras_binary(os.path.join(self.config.root_dir, 'sparse/0/cameras.bin')) + + H = int(camdata[1].height) + W = int(camdata[1].width) + + if 'img_wh' in self.config: + w, h = self.config.img_wh + assert round(W / w * h) == H + elif 'img_downscale' in self.config: + w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5) + else: + raise KeyError("Either img_wh or img_downscale should be specified.") + + img_wh = (w, h) + factor = w / W + + if camdata[1].model == 'SIMPLE_RADIAL': + fx = fy = camdata[1].params[0] * factor + cx = camdata[1].params[1] * factor + cy = camdata[1].params[2] * factor + elif camdata[1].model in ['PINHOLE', 'OPENCV']: + fx = camdata[1].params[0] * factor + fy = camdata[1].params[1] * factor + cx = camdata[1].params[2] * factor + cy = camdata[1].params[3] * factor + else: + raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!") + + directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank) + + imdata = read_images_binary(os.path.join(self.config.root_dir, 'sparse/0/images.bin')) + + mask_dir = os.path.join(self.config.root_dir, 'masks') + has_mask = os.path.exists(mask_dir) # TODO: support partial masks + apply_mask = has_mask and self.config.apply_mask + + all_c2w, all_images, all_fg_masks = [], [], [] + + for i, d in enumerate(imdata.values()): + R = d.qvec2rotmat() + t = d.tvec.reshape(3, 1) + c2w = torch.from_numpy(np.concatenate([R.T, -R.T@t], axis=1)).float() + c2w[:,1:3] *= -1. 
# COLMAP => OpenGL + all_c2w.append(c2w) + if self.split in ['train', 'val']: + img_path = os.path.join(self.config.root_dir, 'images', d.name) + img = Image.open(img_path) + img = img.resize(img_wh, Image.BICUBIC) + img = TF.to_tensor(img).permute(1, 2, 0)[...,:3] + img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu() + if has_mask: + mask_paths = [os.path.join(mask_dir, d.name), os.path.join(mask_dir, d.name[3:])] + mask_paths = list(filter(os.path.exists, mask_paths)) + assert len(mask_paths) == 1 + mask = Image.open(mask_paths[0]).convert('L') # (H, W, 1) + mask = mask.resize(img_wh, Image.BICUBIC) + mask = TF.to_tensor(mask)[0] + else: + mask = torch.ones_like(img[...,0], device=img.device) + all_fg_masks.append(mask) # (h, w) + all_images.append(img) + + all_c2w = torch.stack(all_c2w, dim=0) + + pts3d = read_points3d_binary(os.path.join(self.config.root_dir, 'sparse/0/points3D.bin')) + pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float() + all_c2w, pts3d = normalize_poses(all_c2w, pts3d, up_est_method=self.config.up_est_method, center_est_method=self.config.center_est_method) + + ColmapDatasetBase.properties = { + 'w': w, + 'h': h, + 'img_wh': img_wh, + 'factor': factor, + 'has_mask': has_mask, + 'apply_mask': apply_mask, + 'directions': directions, + 'pts3d': pts3d, + 'all_c2w': all_c2w, + 'all_images': all_images, + 'all_fg_masks': all_fg_masks + } + + ColmapDatasetBase.initialized = True + + for k, v in ColmapDatasetBase.properties.items(): + setattr(self, k, v) + + if self.split == 'test': + self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps) + self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32) + self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32) + else: + self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0).float(), torch.stack(self.all_fg_masks, dim=0).float() + + """ + # for debug use + from models.ray_utils import get_rays + rays_o, rays_d = get_rays(self.directions.cpu(), self.all_c2w, keepdim=True) + pts_out = [] + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 0.0 0.0' for l in rays_o[:,0,0].reshape(-1, 3).tolist()])) + + t_vals = torch.linspace(0, 1, 8) + z_vals = 0.05 * (1 - t_vals) + 0.5 * t_vals + + ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,0][..., None, :]) + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 0.0' for l in ray_pts.view(-1, 3).tolist()])) + + ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,0][..., None, :]) + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 1.0' for l in ray_pts.view(-1, 3).tolist()])) + + ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,self.w-1][..., None, :]) + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()])) + + ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,self.w-1][..., None, :]) + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()])) + + open('cameras.txt', 'w').write('\n'.join(pts_out)) + open('scene.txt', 'w').write('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 0.0' for l in self.pts3d.view(-1, 3).tolist()])) + + exit(1) + """ + + self.all_c2w = self.all_c2w.float().to(self.rank) + if self.config.load_data_on_gpu: + 
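# optionally keep the cached images and masks on the GPU so rays can be sampled without per-step host-to-device copies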
self.all_images = self.all_images.to(self.rank) + self.all_fg_masks = self.all_fg_masks.to(self.rank) + + +class ColmapDataset(Dataset, ColmapDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return { + 'index': index + } + + +class ColmapIterableDataset(IterableDataset, ColmapDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register('colmap') +class ColmapDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, 'fit']: + self.train_dataset = ColmapIterableDataset(self.config, 'train') + if stage in [None, 'fit', 'validate']: + self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train')) + if stage in [None, 'test']: + self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test')) + if stage in [None, 'predict']: + self.predict_dataset = ColmapDataset(self.config, 'train') + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/instant-nsr-pl/datasets/colmap_utils.py b/instant-nsr-pl/datasets/colmap_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5064d53fc4b3a738fc8ab6e52c7a5fee853d16 --- /dev/null +++ b/instant-nsr-pl/datasets/colmap_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch at inf.ethz.ch) + +import os +import collections +import numpy as np +import struct + + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ + for camera_model in CAMERA_MODELS]) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. 
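+ Example: read_next_bytes(fid, 24, "iiQQ") unpacks a camera header as (camera_id, model_id, width, height), as used in read_cameras_binary below.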
+ """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for camera_line_index in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8*num_params, + format_char_sequence="d"*num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def read_images_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for image_index in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, 
"c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, + format_char_sequence="ddq"*num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def read_points3d_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for point_line_index in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8*track_length, + format_char_sequence="ii"*track_length) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def read_model(path, ext): + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy 
- Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec diff --git a/instant-nsr-pl/datasets/dtu.py b/instant-nsr-pl/datasets/dtu.py new file mode 100644 index 0000000000000000000000000000000000000000..39e3a36c54e95ca436ca99cc1e4d94d291c52b11 --- /dev/null +++ b/instant-nsr-pl/datasets/dtu.py @@ -0,0 +1,201 @@ +import os +import json +import math +import numpy as np +from PIL import Image +import cv2 + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF + +import pytorch_lightning as pl + +import datasets +from models.ray_utils import get_ray_directions +from utils.misc import get_rank + + +def load_K_Rt_from_P(P=None): + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose + +def create_spheric_poses(cameras, n_steps=120): + center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device) + cam_center = F.normalize(cameras.mean(0), p=2, dim=-1) * cameras.mean(0).norm(2) + eigvecs = torch.linalg.eig(cameras.T @ cameras).eigenvectors + rot_axis = F.normalize(eigvecs[:,1].real.float(), p=2, dim=-1) + up = rot_axis + rot_dir = torch.cross(rot_axis, cam_center) + max_angle = (F.normalize(cameras, p=2, dim=-1) * F.normalize(cam_center, p=2, dim=-1)).sum(-1).acos().max() + + all_c2w = [] + for theta in torch.linspace(-max_angle, max_angle, n_steps): + cam_pos = cam_center * math.cos(theta) + rot_dir * math.sin(theta) + l = F.normalize(center - cam_pos, p=2, dim=0) + s = F.normalize(l.cross(up), p=2, dim=0) + u = F.normalize(s.cross(l), p=2, dim=0) + c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1) + all_c2w.append(c2w) + + all_c2w = torch.stack(all_c2w, dim=0) + + return all_c2w + +class DTUDatasetBase(): + def setup(self, config, split): + self.config = config + self.split = split + self.rank = get_rank() + + cams = np.load(os.path.join(self.config.root_dir, self.config.cameras_file)) + + img_sample = cv2.imread(os.path.join(self.config.root_dir, 'image', '000000.png')) + H, W = img_sample.shape[0], img_sample.shape[1] + + if 'img_wh' in self.config: + w, h = self.config.img_wh + assert round(W / w * h) == H + elif 'img_downscale' in self.config: + w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5) + else: + raise KeyError("Either img_wh or img_downscale should be specified.") + + self.w, self.h = w, h + self.img_wh = (w, h) + self.factor = w / W + + mask_dir = os.path.join(self.config.root_dir, 'mask') + self.has_mask = True + self.apply_mask = self.config.apply_mask + + self.directions = [] + self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] + + n_images = max([int(k.split('_')[-1]) for k in cams.keys()]) + 1 + + for i in range(n_images): + world_mat, scale_mat = cams[f'world_mat_{i}'], cams[f'scale_mat_{i}'] + P = (world_mat @ scale_mat)[:3,:4] + K, c2w = load_K_Rt_from_P(P) + fx, fy, cx, cy = K[0,0] * self.factor, K[1,1] * self.factor, K[0,2] * self.factor, K[1,2] * self.factor + directions = get_ray_directions(w, h, fx, fy, cx, 
cy) + self.directions.append(directions) + + c2w = torch.from_numpy(c2w).float() + + # blender follows opengl camera coordinates (right up back) + # NeuS DTU data coordinate system (right down front) is different from blender + # https://github.com/Totoro97/NeuS/issues/9 + # for c2w, flip the sign of input camera coordinate yz + c2w_ = c2w.clone() + c2w_[:3,1:3] *= -1. # flip input sign + self.all_c2w.append(c2w_[:3,:4]) + + if self.split in ['train', 'val']: + img_path = os.path.join(self.config.root_dir, 'image', f'{i:06d}.png') + img = Image.open(img_path) + img = img.resize(self.img_wh, Image.BICUBIC) + img = TF.to_tensor(img).permute(1, 2, 0)[...,:3] + + mask_path = os.path.join(mask_dir, f'{i:03d}.png') + mask = Image.open(mask_path).convert('L') # (H, W, 1) + mask = mask.resize(self.img_wh, Image.BICUBIC) + mask = TF.to_tensor(mask)[0] + + self.all_fg_masks.append(mask) # (h, w) + self.all_images.append(img) + + self.all_c2w = torch.stack(self.all_c2w, dim=0) + + if self.split == 'test': + self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps) + self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32) + self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32) + self.directions = self.directions[0] + else: + self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0), torch.stack(self.all_fg_masks, dim=0) + self.directions = torch.stack(self.directions, dim=0) + + self.directions = self.directions.float().to(self.rank) + self.all_c2w, self.all_images, self.all_fg_masks = \ + self.all_c2w.float().to(self.rank), \ + self.all_images.float().to(self.rank), \ + self.all_fg_masks.float().to(self.rank) + + +class DTUDataset(Dataset, DTUDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return { + 'index': index + } + + +class DTUIterableDataset(IterableDataset, DTUDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register('dtu') +class DTUDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, 'fit']: + self.train_dataset = DTUIterableDataset(self.config, 'train') + if stage in [None, 'fit', 'validate']: + self.val_dataset = DTUDataset(self.config, self.config.get('val_split', 'train')) + if stage in [None, 'test']: + self.test_dataset = DTUDataset(self.config, self.config.get('test_split', 'test')) + if stage in [None, 'predict']: + self.predict_dataset = DTUDataset(self.config, 'train') + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/instant-nsr-pl/datasets/fixed_poses/000_back_RT.txt b/instant-nsr-pl/datasets/fixed_poses/000_back_RT.txt new file mode 100644 index 
0000000000000000000000000000000000000000..0b839ed2505438786e2d33bd779b77ed1eedb778 --- /dev/null +++ b/instant-nsr-pl/datasets/fixed_poses/000_back_RT.txt @@ -0,0 +1,3 @@ +-1.000000238418579102e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 +0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 1.746665105883948854e-07 +0.000000000000000000e+00 1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00 diff --git a/instant-nsr-pl/datasets/fixed_poses/000_back_left_RT.txt b/instant-nsr-pl/datasets/fixed_poses/000_back_left_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..97b10e711b1a86782cb69798051df209e8943b19 --- /dev/null +++ b/instant-nsr-pl/datasets/fixed_poses/000_back_left_RT.txt @@ -0,0 +1,3 @@ +-7.071069478988647461e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07 +0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08 +-7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 diff --git a/instant-nsr-pl/datasets/fixed_poses/000_back_right_RT.txt b/instant-nsr-pl/datasets/fixed_poses/000_back_right_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c7ce665f9ee958fe56e1589f52e4e772f3069e1 --- /dev/null +++ b/instant-nsr-pl/datasets/fixed_poses/000_back_right_RT.txt @@ -0,0 +1,3 @@ +-7.071069478988647461e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07 +0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08 +7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 diff --git a/instant-nsr-pl/datasets/fixed_poses/000_front_RT.txt b/instant-nsr-pl/datasets/fixed_poses/000_front_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..67db8bce2207aabc0b8fcf9db25a0af8b9dd9e7b --- /dev/null +++ b/instant-nsr-pl/datasets/fixed_poses/000_front_RT.txt @@ -0,0 +1,3 @@ +1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 +0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 -1.746665105883948854e-07 +0.000000000000000000e+00 -1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00 diff --git a/instant-nsr-pl/datasets/fixed_poses/000_front_left_RT.txt b/instant-nsr-pl/datasets/fixed_poses/000_front_left_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..bed4b8cf8913b5fbf1ec092bceea4da0e4014133 --- /dev/null +++ b/instant-nsr-pl/datasets/fixed_poses/000_front_left_RT.txt @@ -0,0 +1,3 @@ +7.071067690849304199e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07 +0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08 +-7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 diff --git a/instant-nsr-pl/datasets/fixed_poses/000_front_right_RT.txt b/instant-nsr-pl/datasets/fixed_poses/000_front_right_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..56064b9ddb2afa5ae1db28cd70a93018c1f59c33 --- /dev/null +++ b/instant-nsr-pl/datasets/fixed_poses/000_front_right_RT.txt @@ -0,0 +1,3 @@ +7.071067690849304199e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07 +0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 
-9.863901340168013121e-08 +7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 diff --git a/instant-nsr-pl/datasets/fixed_poses/000_left_RT.txt b/instant-nsr-pl/datasets/fixed_poses/000_left_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..465ebaee41f28ba09c6e44451a9c200d4c23bf95 --- /dev/null +++ b/instant-nsr-pl/datasets/fixed_poses/000_left_RT.txt @@ -0,0 +1,3 @@ +-2.220446049250313081e-16 -1.000000000000000000e+00 0.000000000000000000e+00 -2.886579758146288598e-16 +0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 +-1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00 diff --git a/instant-nsr-pl/datasets/fixed_poses/000_right_RT.txt b/instant-nsr-pl/datasets/fixed_poses/000_right_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..2a0c740f885267b285a6585ad4058536205181c5 --- /dev/null +++ b/instant-nsr-pl/datasets/fixed_poses/000_right_RT.txt @@ -0,0 +1,3 @@ +-2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 2.886579758146288598e-16 +0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 +1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00 diff --git a/instant-nsr-pl/datasets/fixed_poses/000_top_RT.txt b/instant-nsr-pl/datasets/fixed_poses/000_top_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..eba7ea36b7d091f390bae16d1428b52b5287bef0 --- /dev/null +++ b/instant-nsr-pl/datasets/fixed_poses/000_top_RT.txt @@ -0,0 +1,3 @@ +1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 +0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 +0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 -1.299999952316284180e+00 diff --git a/instant-nsr-pl/datasets/ortho.py b/instant-nsr-pl/datasets/ortho.py new file mode 100644 index 0000000000000000000000000000000000000000..b29664e1ebda5baf64e57d56e21250cf4a7692ba --- /dev/null +++ b/instant-nsr-pl/datasets/ortho.py @@ -0,0 +1,287 @@ +import os +import json +import math +import numpy as np +from PIL import Image +import cv2 + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF + +import pytorch_lightning as pl + +import datasets +from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays, get_ray_directions +from utils.misc import get_rank + +from glob import glob +import PIL.Image + + +def camNormal2worldNormal(rot_c2w, camNormal): + H,W,_ = camNormal.shape + normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + + return normal_img + +def worldNormal2camNormal(rot_w2c, worldNormal): + H,W,_ = worldNormal.shape + normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + + return normal_img + +def trans_normal(normal, RT_w2c, RT_w2c_target): + + normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) + normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world) + + return normal_target_cam + +def img2normal(img): + return (img/255.)*2-1 + +def normal2img(normal): + return np.uint8((normal*0.5+0.5)*255) + +def norm_normalize(normal, dim=-1): + + 
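# rescale to unit length along `dim`; the small epsilon guards against division by zero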
normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6) + + return normal + +def RT_opengl2opencv(RT): + # Build the coordinate transform matrix from world to computer vision camera + # R_world2cv = R_bcam2cv@R_world2bcam + # T_world2cv = R_bcam2cv@T_world2bcam + + R = RT[:3, :3] + t = RT[:3, 3] + + R_bcam2cv = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32) + + R_world2cv = R_bcam2cv @ R + t_world2cv = R_bcam2cv @ t + + RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1) + + return RT + +def normal_opengl2opencv(normal): + H,W,C = np.shape(normal) + # normal_img = np.reshape(normal, (H*W,C)) + R_bcam2cv = np.array([1, -1, -1], np.float32) + normal_cv = normal * R_bcam2cv[None, None, :] + + print(np.shape(normal_cv)) + + return normal_cv + +def inv_RT(RT): + RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0) + RT_inv = np.linalg.inv(RT_h) + + return RT_inv[:3, :] + + +def load_a_prediction(root_dir, test_object, imSize, view_types, load_color=False, cam_pose_dir=None, + normal_system='front', erode_mask=True, camera_type='ortho', cam_params=None): + + all_images = [] + all_normals = [] + all_normals_world = [] + all_masks = [] + all_color_masks = [] + all_poses = [] + all_w2cs = [] + directions = [] + ray_origins = [] + + RT_front = np.loadtxt(glob(os.path.join(cam_pose_dir, '*_%s_RT.txt'%( 'front')))[0]) # world2cam matrix + RT_front_cv = RT_opengl2opencv(RT_front) # convert normal from opengl to opencv + for idx, view in enumerate(view_types): + print(os.path.join(root_dir,test_object)) + normal_filepath = os.path.join(root_dir, test_object, 'normals_000_%s.png'%( view)) + # Load key frame + if load_color: # use bgr + image =np.array(PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize))[:, :, :3] + + normal = np.array(PIL.Image.open(normal_filepath).resize(imSize)) + mask = normal[:, :, 3] + normal = normal[:, :, :3] + + color_mask = np.array(PIL.Image.open(os.path.join(root_dir,test_object, 'masked_colors/rgb_000_%s.png'%( view))).resize(imSize))[:, :, 3] + invalid_color_mask = color_mask < 255*0.5 + threshold = np.ones_like(image[:, :, 0]) * 250 + invalid_white_mask = (image[:, :, 0] > threshold) & (image[:, :, 1] > threshold) & (image[:, :, 2] > threshold) + invalid_color_mask_final = invalid_color_mask & invalid_white_mask + color_mask = (1 - invalid_color_mask_final) > 0 + + # if erode_mask: + # kernel = np.ones((3, 3), np.uint8) + # mask = cv2.erode(mask, kernel, iterations=1) + + RT = np.loadtxt(os.path.join(cam_pose_dir, '000_%s_RT.txt'%( view))) # world2cam matrix + + normal = img2normal(normal) + + normal[mask==0] = [0,0,0] + mask = mask> (0.5*255) + if load_color: + all_images.append(image) + + all_masks.append(mask) + all_color_masks.append(color_mask) + RT_cv = RT_opengl2opencv(RT) # convert normal from opengl to opencv + all_poses.append(inv_RT(RT_cv)) # cam2world + all_w2cs.append(RT_cv) + + # whether to + normal_cam_cv = normal_opengl2opencv(normal) + + if normal_system == 'front': + print("the loaded normals are defined in the system of front view") + normal_world = camNormal2worldNormal(inv_RT(RT_front_cv)[:3, :3], normal_cam_cv) + elif normal_system == 'self': + print("the loaded normals are in their independent camera systems") + normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv) + all_normals.append(normal_cam_cv) + all_normals_world.append(normal_world) + + if camera_type == 'ortho': + origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1]) + elif camera_type == 
'pinhole': + dirs = get_ray_directions(W=imSize[0], H=imSize[1], + fx=cam_params[0], fy=cam_params[1], cx=cam_params[2], cy=cam_params[3]) + origins = dirs # occupy a position + else: + raise Exception("not support camera type") + ray_origins.append(origins) + directions.append(dirs) + + + if not load_color: + all_images = [normal2img(x) for x in all_normals_world] + + + return np.stack(all_images), np.stack(all_masks), np.stack(all_normals), \ + np.stack(all_normals_world), np.stack(all_poses), np.stack(all_w2cs), np.stack(ray_origins), np.stack(directions), np.stack(all_color_masks) + + +class OrthoDatasetBase(): + def setup(self, config, split): + self.config = config + self.split = split + self.rank = get_rank() + + self.data_dir = self.config.root_dir + self.object_name = self.config.scene + self.scene = self.config.scene + self.imSize = self.config.imSize + self.load_color = True + self.img_wh = [self.imSize[0], self.imSize[1]] + self.w = self.img_wh[0] + self.h = self.img_wh[1] + self.camera_type = self.config.camera_type + self.camera_params = self.config.camera_params # [fx, fy, cx, cy] + + self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] + + self.view_weights = torch.from_numpy(np.array(self.config.view_weights)).float().to(self.rank).view(-1) + self.view_weights = self.view_weights.view(-1,1,1).repeat(1, self.h, self.w) + + if self.config.cam_pose_dir is None: + self.cam_pose_dir = "./datasets/fixed_poses" + else: + self.cam_pose_dir = self.config.cam_pose_dir + + self.images_np, self.masks_np, self.normals_cam_np, self.normals_world_np, \ + self.pose_all_np, self.w2c_all_np, self.origins_np, self.directions_np, self.rgb_masks_np = load_a_prediction( + self.data_dir, self.object_name, self.imSize, self.view_types, + self.load_color, self.cam_pose_dir, normal_system='front', + camera_type=self.camera_type, cam_params=self.camera_params) + + self.has_mask = True + self.apply_mask = self.config.apply_mask + + self.all_c2w = torch.from_numpy(self.pose_all_np) + self.all_images = torch.from_numpy(self.images_np) / 255. 
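+ # RGB values are scaled to [0, 1]; the foreground masks below come from the alpha channel of the predicted normal maps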
+ self.all_fg_masks = torch.from_numpy(self.masks_np) + self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np) + self.all_normals_world = torch.from_numpy(self.normals_world_np) + self.origins = torch.from_numpy(self.origins_np) + self.directions = torch.from_numpy(self.directions_np) + + self.directions = self.directions.float().to(self.rank) + self.origins = self.origins.float().to(self.rank) + self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank) + self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = \ + self.all_c2w.float().to(self.rank), \ + self.all_images.float().to(self.rank), \ + self.all_fg_masks.float().to(self.rank), \ + self.all_normals_world.float().to(self.rank) + + +class OrthoDataset(Dataset, OrthoDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return { + 'index': index + } + + +class OrthoIterableDataset(IterableDataset, OrthoDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register('ortho') +class OrthoDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, 'fit']: + self.train_dataset = OrthoIterableDataset(self.config, 'train') + if stage in [None, 'fit', 'validate']: + self.val_dataset = OrthoDataset(self.config, self.config.get('val_split', 'train')) + if stage in [None, 'test']: + self.test_dataset = OrthoDataset(self.config, self.config.get('test_split', 'test')) + if stage in [None, 'predict']: + self.predict_dataset = OrthoDataset(self.config, 'train') + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/instant-nsr-pl/datasets/utils.py b/instant-nsr-pl/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/instant-nsr-pl/launch.py b/instant-nsr-pl/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..50cc6bea0e9627a819ea74bde50f5f8707da957c --- /dev/null +++ b/instant-nsr-pl/launch.py @@ -0,0 +1,125 @@ +import sys +import argparse +import os +import time +import logging +from datetime import datetime + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', required=True, help='path to config file') + parser.add_argument('--gpu', default='0', help='GPU(s) to be used') + parser.add_argument('--resume', default=None, help='path to the weights to be resumed') + parser.add_argument( + '--resume_weights_only', + action='store_true', + help='specify this argument to restore only the weights (w/o training states), e.g. 
--resume path/to/resume --resume_weights_only' + ) + + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--train', action='store_true') + group.add_argument('--validate', action='store_true') + group.add_argument('--test', action='store_true') + group.add_argument('--predict', action='store_true') + # group.add_argument('--export', action='store_true') # TODO: a separate export action + + parser.add_argument('--exp_dir', default='./exp') + parser.add_argument('--runs_dir', default='./runs') + parser.add_argument('--verbose', action='store_true', help='if true, set logging level to DEBUG') + + args, extras = parser.parse_known_args() + + # set CUDA_VISIBLE_DEVICES then import pytorch-lightning + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + n_gpus = len(args.gpu.split(',')) + + import datasets + import systems + import pytorch_lightning as pl + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor + from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger + from utils.callbacks import CodeSnapshotCallback, ConfigSnapshotCallback, CustomProgressBar + from utils.misc import load_config + + # parse YAML config to OmegaConf + config = load_config(args.config, cli_args=extras) + config.cmd_args = vars(args) + + config.trial_name = config.get('trial_name') or (config.tag + datetime.now().strftime('@%Y%m%d-%H%M%S')) + config.exp_dir = config.get('exp_dir') or os.path.join(args.exp_dir, config.name) + config.save_dir = config.get('save_dir') or os.path.join(config.exp_dir, config.trial_name, 'save') + config.ckpt_dir = config.get('ckpt_dir') or os.path.join(config.exp_dir, config.trial_name, 'ckpt') + config.code_dir = config.get('code_dir') or os.path.join(config.exp_dir, config.trial_name, 'code') + config.config_dir = config.get('config_dir') or os.path.join(config.exp_dir, config.trial_name, 'config') + + logger = logging.getLogger('pytorch_lightning') + if args.verbose: + logger.setLevel(logging.DEBUG) + + if 'seed' not in config: + config.seed = int(time.time() * 1000) % 1000 + pl.seed_everything(config.seed) + + dm = datasets.make(config.dataset.name, config.dataset) + system = systems.make(config.system.name, config, load_from_checkpoint=None if not args.resume_weights_only else args.resume) + + callbacks = [] + if args.train: + callbacks += [ + ModelCheckpoint( + dirpath=config.ckpt_dir, + **config.checkpoint + ), + LearningRateMonitor(logging_interval='step'), + CodeSnapshotCallback( + config.code_dir, use_version=False + ), + ConfigSnapshotCallback( + config, config.config_dir, use_version=False + ), + CustomProgressBar(refresh_rate=1), + ] + + loggers = [] + if args.train: + loggers += [ + TensorBoardLogger(args.runs_dir, name=config.name, version=config.trial_name), + CSVLogger(config.exp_dir, name=config.trial_name, version='csv_logs') + ] + + if sys.platform == 'win32': + # does not support multi-gpu on windows + strategy = 'dp' + assert n_gpus == 1 + else: + strategy = 'ddp_find_unused_parameters_false' + + trainer = Trainer( + devices=n_gpus, + accelerator='gpu', + callbacks=callbacks, + logger=loggers, + strategy=strategy, + **config.trainer + ) + + if args.train: + if args.resume and not args.resume_weights_only: + # FIXME: different behavior in pytorch-lighting>1.9 ? 
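+ # resuming with ckpt_path restores optimizer/scheduler state as well as the weights (unlike --resume_weights_only)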
+ trainer.fit(system, datamodule=dm, ckpt_path=args.resume) + else: + trainer.fit(system, datamodule=dm) + trainer.test(system, datamodule=dm) + elif args.validate: + trainer.validate(system, datamodule=dm, ckpt_path=args.resume) + elif args.test: + trainer.test(system, datamodule=dm, ckpt_path=args.resume) + elif args.predict: + trainer.predict(system, datamodule=dm, ckpt_path=args.resume) + + +if __name__ == '__main__': + main() diff --git a/instant-nsr-pl/models/__init__.py b/instant-nsr-pl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0464e9d6f4706af1829aa0e18c8ecd89203baff --- /dev/null +++ b/instant-nsr-pl/models/__init__.py @@ -0,0 +1,16 @@ +models = {} + + +def register(name): + def decorator(cls): + models[name] = cls + return cls + return decorator + + +def make(name, config): + model = models[name](config) + return model + + +from . import nerf, neus, geometry, texture diff --git a/instant-nsr-pl/models/base.py b/instant-nsr-pl/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..47b853bc9502ff2581639b5a6bc7313ffe0ec9ec --- /dev/null +++ b/instant-nsr-pl/models/base.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + +from utils.misc import get_rank + +class BaseModel(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.rank = get_rank() + self.setup() + if self.config.get('weights', None): + self.load_state_dict(torch.load(self.config.weights)) + + def setup(self): + raise NotImplementedError + + def update_step(self, epoch, global_step): + pass + + def train(self, mode=True): + return super().train(mode=mode) + + def eval(self): + return super().eval() + + def regularizations(self, out): + return {} + + @torch.no_grad() + def export(self, export_config): + return {} diff --git a/instant-nsr-pl/models/geometry.py b/instant-nsr-pl/models/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..861edbe2726bb19e7c705c419837f369de170f28 --- /dev/null +++ b/instant-nsr-pl/models/geometry.py @@ -0,0 +1,238 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pytorch_lightning.utilities.rank_zero import rank_zero_info + +import models +from models.base import BaseModel +from models.utils import scale_anything, get_activation, cleanup, chunk_batch +from models.network_utils import get_encoding, get_mlp, get_encoding_with_network +from utils.misc import get_rank +from systems.utils import update_module_step +from nerfacc import ContractionType + + +def contract_to_unisphere(x, radius, contraction_type): + if contraction_type == ContractionType.AABB: + x = scale_anything(x, (-radius, radius), (0, 1)) + elif contraction_type == ContractionType.UN_BOUNDED_SPHERE: + x = scale_anything(x, (-radius, radius), (0, 1)) + x = x * 2 - 1 # aabb is at [-1, 1] + mag = x.norm(dim=-1, keepdim=True) + mask = mag.squeeze(-1) > 1 + x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask]) + x = x / 4 + 0.5 # [-inf, inf] is at [0, 1] + else: + raise NotImplementedError + return x + + +class MarchingCubeHelper(nn.Module): + def __init__(self, resolution, use_torch=True): + super().__init__() + self.resolution = resolution + self.use_torch = use_torch + self.points_range = (0, 1) + if self.use_torch: + import torchmcubes + self.mc_func = torchmcubes.marching_cubes + else: + import mcubes + self.mc_func = mcubes.marching_cubes + self.verts = None + + def grid_vertices(self): + if self.verts is None: + x, y, z = 
torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution) + x, y, z = torch.meshgrid(x, y, z, indexing='ij') + verts = torch.cat([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1).reshape(-1, 3) + self.verts = verts + return self.verts + + def forward(self, level, threshold=0.): + level = level.float().view(self.resolution, self.resolution, self.resolution) + if self.use_torch: + verts, faces = self.mc_func(level.to(get_rank()), threshold) + verts, faces = verts.cpu(), faces.cpu().long() + else: + verts, faces = self.mc_func(-level.numpy(), threshold) # transform to numpy + verts, faces = torch.from_numpy(verts.astype(np.float32)), torch.from_numpy(faces.astype(np.int64)) # transform back to pytorch + verts = verts / (self.resolution - 1.) + return { + 'v_pos': verts, + 't_pos_idx': faces + } + + +class BaseImplicitGeometry(BaseModel): + def __init__(self, config): + super().__init__(config) + if self.config.isosurface is not None: + assert self.config.isosurface.method in ['mc', 'mc-torch'] + if self.config.isosurface.method == 'mc-torch': + raise NotImplementedError("Please do not use mc-torch. It currently has some scaling issues I haven't fixed yet.") + self.helper = MarchingCubeHelper(self.config.isosurface.resolution, use_torch=self.config.isosurface.method=='mc-torch') + self.radius = self.config.radius + self.contraction_type = None # assigned in system + + def forward_level(self, points): + raise NotImplementedError + + def isosurface_(self, vmin, vmax): + def batch_func(x): + x = torch.stack([ + scale_anything(x[...,0], (0, 1), (vmin[0], vmax[0])), + scale_anything(x[...,1], (0, 1), (vmin[1], vmax[1])), + scale_anything(x[...,2], (0, 1), (vmin[2], vmax[2])), + ], dim=-1).to(self.rank) + rv = self.forward_level(x).cpu() + cleanup() + return rv + + level = chunk_batch(batch_func, self.config.isosurface.chunk, True, self.helper.grid_vertices()) + mesh = self.helper(level, threshold=self.config.isosurface.threshold) + mesh['v_pos'] = torch.stack([ + scale_anything(mesh['v_pos'][...,0], (0, 1), (vmin[0], vmax[0])), + scale_anything(mesh['v_pos'][...,1], (0, 1), (vmin[1], vmax[1])), + scale_anything(mesh['v_pos'][...,2], (0, 1), (vmin[2], vmax[2])) + ], dim=-1) + return mesh + + @torch.no_grad() + def isosurface(self): + if self.config.isosurface is None: + raise NotImplementedError + mesh_coarse = self.isosurface_((-self.radius, -self.radius, -self.radius), (self.radius, self.radius, self.radius)) + vmin, vmax = mesh_coarse['v_pos'].amin(dim=0), mesh_coarse['v_pos'].amax(dim=0) + vmin_ = (vmin - (vmax - vmin) * 0.1).clamp(-self.radius, self.radius) + vmax_ = (vmax + (vmax - vmin) * 0.1).clamp(-self.radius, self.radius) + mesh_fine = self.isosurface_(vmin_, vmax_) + return mesh_fine + + +@models.register('volume-density') +class VolumeDensity(BaseImplicitGeometry): + def setup(self): + self.n_input_dims = self.config.get('n_input_dims', 3) + self.n_output_dims = self.config.feature_dim + self.encoding_with_network = get_encoding_with_network(self.n_input_dims, self.n_output_dims, self.config.xyz_encoding_config, self.config.mlp_network_config) + + def forward(self, points): + points = contract_to_unisphere(points, self.radius, self.contraction_type) + out = self.encoding_with_network(points.view(-1, self.n_input_dims)).view(*points.shape[:-1], self.n_output_dims).float() + density, feature = out[...,0], out + if 'density_activation' in self.config: + density = 
get_activation(self.config.density_activation)(density + float(self.config.density_bias)) + if 'feature_activation' in self.config: + feature = get_activation(self.config.feature_activation)(feature) + return density, feature + + def forward_level(self, points): + points = contract_to_unisphere(points, self.radius, self.contraction_type) + density = self.encoding_with_network(points.reshape(-1, self.n_input_dims)).reshape(*points.shape[:-1], self.n_output_dims)[...,0] + if 'density_activation' in self.config: + density = get_activation(self.config.density_activation)(density + float(self.config.density_bias)) + return -density + + def update_step(self, epoch, global_step): + update_module_step(self.encoding_with_network, epoch, global_step) + + +@models.register('volume-sdf') +class VolumeSDF(BaseImplicitGeometry): + def setup(self): + self.n_output_dims = self.config.feature_dim + encoding = get_encoding(3, self.config.xyz_encoding_config) + network = get_mlp(encoding.n_output_dims, self.n_output_dims, self.config.mlp_network_config) + self.encoding, self.network = encoding, network + self.grad_type = self.config.grad_type + self.finite_difference_eps = self.config.get('finite_difference_eps', 1e-3) + # the actual value used in training + # will update at certain steps if finite_difference_eps="progressive" + self._finite_difference_eps = None + if self.grad_type == 'finite_difference': + rank_zero_info(f"Using finite difference to compute gradients with eps={self.finite_difference_eps}") + + def forward(self, points, with_grad=True, with_feature=True, with_laplace=False): + with torch.inference_mode(torch.is_inference_mode_enabled() and not (with_grad and self.grad_type == 'analytic')): + with torch.set_grad_enabled(self.training or (with_grad and self.grad_type == 'analytic')): + if with_grad and self.grad_type == 'analytic': + if not self.training: + points = points.clone() # points may be in inference mode, get a copy to enable grad + points.requires_grad_(True) + + points_ = points # points in the original scale + points = contract_to_unisphere(points, self.radius, self.contraction_type) # points normalized to (0, 1) + + out = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims).float() + sdf, feature = out[...,0], out + if 'sdf_activation' in self.config: + sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias)) + if 'feature_activation' in self.config: + feature = get_activation(self.config.feature_activation)(feature) + if with_grad: + if self.grad_type == 'analytic': + grad = torch.autograd.grad( + sdf, points_, grad_outputs=torch.ones_like(sdf), + create_graph=True, retain_graph=True, only_inputs=True + )[0] + elif self.grad_type == 'finite_difference': + eps = self._finite_difference_eps + offsets = torch.as_tensor( + [ + [eps, 0.0, 0.0], + [-eps, 0.0, 0.0], + [0.0, eps, 0.0], + [0.0, -eps, 0.0], + [0.0, 0.0, eps], + [0.0, 0.0, -eps], + ] + ).to(points_) + points_d_ = (points_[...,None,:] + offsets).clamp(-self.radius, self.radius) + points_d = scale_anything(points_d_, (-self.radius, self.radius), (0, 1)) + points_d_sdf = self.network(self.encoding(points_d.view(-1, 3)))[...,0].view(*points.shape[:-1], 6).float() + grad = 0.5 * (points_d_sdf[..., 0::2] - points_d_sdf[..., 1::2]) / eps + + if with_laplace: + laplace = (points_d_sdf[..., 0::2] + points_d_sdf[..., 1::2] - 2 * sdf[..., None]).sum(-1) / (eps ** 2) + + rv = [sdf] + if with_grad: + rv.append(grad) + if with_feature: + rv.append(feature) + if 
with_laplace: + assert self.config.grad_type == 'finite_difference', "Laplace computation is only supported with grad_type='finite_difference'" + rv.append(laplace) + rv = [v if self.training else v.detach() for v in rv] + return rv[0] if len(rv) == 1 else rv + + def forward_level(self, points): + points = contract_to_unisphere(points, self.radius, self.contraction_type) # points normalized to (0, 1) + sdf = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims)[...,0] + if 'sdf_activation' in self.config: + sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias)) + return sdf + + def update_step(self, epoch, global_step): + update_module_step(self.encoding, epoch, global_step) + update_module_step(self.network, epoch, global_step) + if self.grad_type == 'finite_difference': + if isinstance(self.finite_difference_eps, float): + self._finite_difference_eps = self.finite_difference_eps + elif self.finite_difference_eps == 'progressive': + hg_conf = self.config.xyz_encoding_config + assert hg_conf.otype == "ProgressiveBandHashGrid", "finite_difference_eps='progressive' only works with ProgressiveBandHashGrid" + current_level = min( + hg_conf.start_level + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps, + hg_conf.n_levels + ) + grid_res = hg_conf.base_resolution * hg_conf.per_level_scale**(current_level - 1) + grid_size = 2 * self.config.radius / grid_res + if grid_size != self._finite_difference_eps: + rank_zero_info(f"Update finite_difference_eps to {grid_size}") + self._finite_difference_eps = grid_size + else: + raise ValueError(f"Unknown finite_difference_eps={self.finite_difference_eps}") diff --git a/instant-nsr-pl/models/nerf.py b/instant-nsr-pl/models/nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..64ce5f9b839828eb02292faa4828108694f5f6d1 --- /dev/null +++ b/instant-nsr-pl/models/nerf.py @@ -0,0 +1,161 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import models +from models.base import BaseModel +from models.utils import chunk_batch +from systems.utils import update_module_step +from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, accumulate_along_rays + + +@models.register('nerf') +class NeRFModel(BaseModel): + def setup(self): + self.geometry = models.make(self.config.geometry.name, self.config.geometry) + self.texture = models.make(self.config.texture.name, self.config.texture) + self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32)) + + if self.config.learned_background: + self.occupancy_grid_res = 256 + self.near_plane, self.far_plane = 0.2, 1e4 + self.cone_angle = 10**(math.log10(self.far_plane) / self.config.num_samples_per_ray) - 1. 
# approximate + self.render_step_size = 0.01 # render_step_size = max(distance_to_camera * self.cone_angle, self.render_step_size) + self.contraction_type = ContractionType.UN_BOUNDED_SPHERE + else: + self.occupancy_grid_res = 128 + self.near_plane, self.far_plane = None, None + self.cone_angle = 0.0 + self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray + self.contraction_type = ContractionType.AABB + + self.geometry.contraction_type = self.contraction_type + + if self.config.grid_prune: + self.occupancy_grid = OccupancyGrid( + roi_aabb=self.scene_aabb, + resolution=self.occupancy_grid_res, + contraction_type=self.contraction_type + ) + self.randomized = self.config.randomized + self.background_color = None + + def update_step(self, epoch, global_step): + update_module_step(self.geometry, epoch, global_step) + update_module_step(self.texture, epoch, global_step) + + def occ_eval_fn(x): + density, _ = self.geometry(x) + # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size) based on taylor series + return density[...,None] * self.render_step_size + + if self.training and self.config.grid_prune: + self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn) + + def isosurface(self): + mesh = self.geometry.isosurface() + return mesh + + def forward_(self, rays): + n_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + + def sigma_fn(t_starts, t_ends, ray_indices): + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends) / 2. + density, _ = self.geometry(positions) + return density[...,None] + + def rgb_sigma_fn(t_starts, t_ends, ray_indices): + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends) / 2. + density, feature = self.geometry(positions) + rgb = self.texture(feature, t_dirs) + return rgb, density[...,None] + + with torch.no_grad(): + ray_indices, t_starts, t_ends = ray_marching( + rays_o, rays_d, + scene_aabb=None if self.config.learned_background else self.scene_aabb, + grid=self.occupancy_grid if self.config.grid_prune else None, + sigma_fn=sigma_fn, + near_plane=self.near_plane, far_plane=self.far_plane, + render_step_size=self.render_step_size, + stratified=self.randomized, + cone_angle=self.cone_angle, + alpha_thre=0.0 + ) + + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + midpoints = (t_starts + t_ends) / 2. 
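+ # Reference note on the accumulation below: render_weight_from_density (nerfacc 0.3.x,
+ # as pinned in requirements.txt) implements the standard emission-absorption quadrature
+ #   alpha_i = 1 - exp(-sigma_i * (t_end_i - t_start_i))
+ #   w_i     = alpha_i * prod_{j<i} (1 - alpha_j)
+ # e.g. sigma_i = 10 over an interval of 0.01 gives alpha_i = 1 - exp(-0.1) ~= 0.095,
+ # close to the first-order approximation sigma_i * delta_i = 0.1 used in occ_eval_fn above.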
+ positions = t_origins + t_dirs * midpoints + intervals = t_ends - t_starts + + density, feature = self.geometry(positions) + rgb = self.texture(feature, t_dirs) + + weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays) + opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays) + depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays) + comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays) + comp_rgb = comp_rgb + self.background_color * (1.0 - opacity) + + out = { + 'comp_rgb': comp_rgb, + 'opacity': opacity, + 'depth': depth, + 'rays_valid': opacity > 0, + 'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device) + } + + if self.training: + out.update({ + 'weights': weights.view(-1), + 'points': midpoints.view(-1), + 'intervals': intervals.view(-1), + 'ray_indices': ray_indices.view(-1) + }) + + return out + + def forward(self, rays): + if self.training: + out = self.forward_(rays) + else: + out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays) + return { + **out, + } + + def train(self, mode=True): + self.randomized = mode and self.config.randomized + return super().train(mode=mode) + + def eval(self): + self.randomized = False + return super().eval() + + def regularizations(self, out): + losses = {} + losses.update(self.geometry.regularizations(out)) + losses.update(self.texture.regularizations(out)) + return losses + + @torch.no_grad() + def export(self, export_config): + mesh = self.isosurface() + if export_config.export_vertex_color: + _, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank)) + viewdirs = torch.zeros(feature.shape[0], 3).to(feature) + viewdirs[...,2] = -1. # set the viewing directions to be -z (looking down) + rgb = self.texture(feature, viewdirs).clamp(0,1) + mesh['v_rgb'] = rgb.cpu() + return mesh diff --git a/instant-nsr-pl/models/network_utils.py b/instant-nsr-pl/models/network_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bf1c4ab64487b68118e62cbc834dc2f1ff908ad7 --- /dev/null +++ b/instant-nsr-pl/models/network_utils.py @@ -0,0 +1,215 @@ +import math +import numpy as np + +import torch +import torch.nn as nn +import tinycudann as tcnn + +from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info + +from utils.misc import config_to_primitive, get_rank +from models.utils import get_activation +from systems.utils import update_module_step + +class VanillaFrequency(nn.Module): + def __init__(self, in_channels, config): + super().__init__() + self.N_freqs = config['n_frequencies'] + self.in_channels, self.n_input_dims = in_channels, in_channels + self.funcs = [torch.sin, torch.cos] + self.freq_bands = 2**torch.linspace(0, self.N_freqs-1, self.N_freqs) + self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs) + self.n_masking_step = config.get('n_masking_step', 0) + self.update_step(None, None) # mask should be updated at the beginning each step + + def forward(self, x): + out = [] + for freq, mask in zip(self.freq_bands, self.mask): + for func in self.funcs: + out += [func(freq*x) * mask] + return torch.cat(out, -1) + + def update_step(self, epoch, global_step): + if self.n_masking_step <= 0 or global_step is None: + self.mask = torch.ones(self.N_freqs, dtype=torch.float32) + else: + self.mask = (1. 
- torch.cos(math.pi * (global_step / self.n_masking_step * self.N_freqs - torch.arange(0, self.N_freqs)).clamp(0, 1))) / 2. + rank_zero_debug(f'Update mask: {global_step}/{self.n_masking_step} {self.mask}') + + +class ProgressiveBandHashGrid(nn.Module): + def __init__(self, in_channels, config): + super().__init__() + self.n_input_dims = in_channels + encoding_config = config.copy() + encoding_config['otype'] = 'HashGrid' + with torch.cuda.device(get_rank()): + self.encoding = tcnn.Encoding(in_channels, encoding_config) + self.n_output_dims = self.encoding.n_output_dims + self.n_level = config['n_levels'] + self.n_features_per_level = config['n_features_per_level'] + self.start_level, self.start_step, self.update_steps = config['start_level'], config['start_step'], config['update_steps'] + self.current_level = self.start_level + self.mask = torch.zeros(self.n_level * self.n_features_per_level, dtype=torch.float32, device=get_rank()) + + def forward(self, x): + enc = self.encoding(x) + enc = enc * self.mask + return enc + + def update_step(self, epoch, global_step): + current_level = min(self.start_level + max(global_step - self.start_step, 0) // self.update_steps, self.n_level) + if current_level > self.current_level: + rank_zero_info(f'Update grid level to {current_level}') + self.current_level = current_level + self.mask[:self.current_level * self.n_features_per_level] = 1. + + +class CompositeEncoding(nn.Module): + def __init__(self, encoding, include_xyz=False, xyz_scale=1., xyz_offset=0.): + super(CompositeEncoding, self).__init__() + self.encoding = encoding + self.include_xyz, self.xyz_scale, self.xyz_offset = include_xyz, xyz_scale, xyz_offset + self.n_output_dims = int(self.include_xyz) * self.encoding.n_input_dims + self.encoding.n_output_dims + + def forward(self, x, *args): + return self.encoding(x, *args) if not self.include_xyz else torch.cat([x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1) + + def update_step(self, epoch, global_step): + update_module_step(self.encoding, epoch, global_step) + + +def get_encoding(n_input_dims, config): + # input suppose to be range [0, 1] + if config.otype == 'VanillaFrequency': + encoding = VanillaFrequency(n_input_dims, config_to_primitive(config)) + elif config.otype == 'ProgressiveBandHashGrid': + encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config)) + else: + with torch.cuda.device(get_rank()): + encoding = tcnn.Encoding(n_input_dims, config_to_primitive(config)) + encoding = CompositeEncoding(encoding, include_xyz=config.get('include_xyz', False), xyz_scale=2., xyz_offset=-1.) 
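+ # include_xyz concatenates the raw coordinates to the encoded features; since inputs are
+ # expected in [0, 1], xyz_scale=2. and xyz_offset=-1. map them back to [-1, 1] first.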
+ return encoding + + +class VanillaMLP(nn.Module): + def __init__(self, dim_in, dim_out, config): + super().__init__() + self.n_neurons, self.n_hidden_layers = config['n_neurons'], config['n_hidden_layers'] + self.sphere_init, self.weight_norm = config.get('sphere_init', False), config.get('weight_norm', False) + self.sphere_init_radius = config.get('sphere_init_radius', 0.5) + self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()] + for i in range(self.n_hidden_layers - 1): + self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()] + self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)] + self.layers = nn.Sequential(*self.layers) + self.output_activation = get_activation(config['output_activation']) + + @torch.cuda.amp.autocast(False) + def forward(self, x): + x = self.layers(x.float()) + x = self.output_activation(x) + return x + + def make_linear(self, dim_in, dim_out, is_first, is_last): + layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality + if self.sphere_init: + if is_last: + torch.nn.init.constant_(layer.bias, -self.sphere_init_radius) + torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001) + elif is_first: + torch.nn.init.constant_(layer.bias, 0.0) + torch.nn.init.constant_(layer.weight[:, 3:], 0.0) + torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out)) + else: + torch.nn.init.constant_(layer.bias, 0.0) + torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out)) + else: + torch.nn.init.constant_(layer.bias, 0.0) + torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu') + + if self.weight_norm: + layer = nn.utils.weight_norm(layer) + return layer + + def make_activation(self): + if self.sphere_init: + return nn.Softplus(beta=100) + else: + return nn.ReLU(inplace=True) + + +def sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network): + rank_zero_debug('Initialize tcnn MLP to approximately represent a sphere.') + """ + from https://github.com/NVlabs/tiny-cuda-nn/issues/96 + It's the weight matrices of each layer laid out in row-major order and then concatenated. + Notably: inputs and output dimensions are padded to multiples of 8 (CutlassMLP) or 16 (FullyFusedMLP). + The padded input dimensions get a constant value of 1.0, + whereas the padded output dimensions are simply ignored, + so the weights pertaining to those can have any value. 
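+ Hence the flattened parameter vector holds
+ (padded_in_dims + padded_out_dims) * n_neurons plus (n_hidden_layers - 1) * n_neurons**2
+ values, which is exactly what the assertion below checks before copying the new weights in.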
+ """ + padto = 16 if config.otype == 'FullyFusedMLP' else 8 + n_input_dims = n_input_dims + (padto - n_input_dims % padto) % padto + n_output_dims = n_output_dims + (padto - n_output_dims % padto) % padto + data = list(network.parameters())[0].data + assert data.shape[0] == (n_input_dims + n_output_dims) * config.n_neurons + (config.n_hidden_layers - 1) * config.n_neurons**2 + new_data = [] + # first layer + weight = torch.zeros((config.n_neurons, n_input_dims)).to(data) + torch.nn.init.constant_(weight[:, 3:], 0.0) + torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(config.n_neurons)) + new_data.append(weight.flatten()) + # hidden layers + for i in range(config.n_hidden_layers - 1): + weight = torch.zeros((config.n_neurons, config.n_neurons)).to(data) + torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(config.n_neurons)) + new_data.append(weight.flatten()) + # last layer + weight = torch.zeros((n_output_dims, config.n_neurons)).to(data) + torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(config.n_neurons), std=0.0001) + new_data.append(weight.flatten()) + new_data = torch.cat(new_data) + data.copy_(new_data) + + +def get_mlp(n_input_dims, n_output_dims, config): + if config.otype == 'VanillaMLP': + network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config)) + else: + with torch.cuda.device(get_rank()): + network = tcnn.Network(n_input_dims, n_output_dims, config_to_primitive(config)) + if config.get('sphere_init', False): + sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network) + return network + + +class EncodingWithNetwork(nn.Module): + def __init__(self, encoding, network): + super().__init__() + self.encoding, self.network = encoding, network + + def forward(self, x): + return self.network(self.encoding(x)) + + def update_step(self, epoch, global_step): + update_module_step(self.encoding, epoch, global_step) + update_module_step(self.network, epoch, global_step) + + +def get_encoding_with_network(n_input_dims, n_output_dims, encoding_config, network_config): + # input suppose to be range [0, 1] + if encoding_config.otype in ['VanillaFrequency', 'ProgressiveBandHashGrid'] \ + or network_config.otype in ['VanillaMLP']: + encoding = get_encoding(n_input_dims, encoding_config) + network = get_mlp(encoding.n_output_dims, n_output_dims, network_config) + encoding_with_network = EncodingWithNetwork(encoding, network) + else: + with torch.cuda.device(get_rank()): + encoding_with_network = tcnn.NetworkWithInputEncoding( + n_input_dims=n_input_dims, + n_output_dims=n_output_dims, + encoding_config=config_to_primitive(encoding_config), + network_config=config_to_primitive(network_config) + ) + return encoding_with_network diff --git a/instant-nsr-pl/models/neus.py b/instant-nsr-pl/models/neus.py new file mode 100644 index 0000000000000000000000000000000000000000..51fb9480059165f91685003bc617548bbfc6c83d --- /dev/null +++ b/instant-nsr-pl/models/neus.py @@ -0,0 +1,341 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import models +from models.base import BaseModel +from models.utils import chunk_batch +from systems.utils import update_module_step +from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, render_weight_from_alpha, accumulate_along_rays +from nerfacc.intersection import ray_aabb_intersect + +import pdb + + +class VarianceNetwork(nn.Module): + def __init__(self, config): + super(VarianceNetwork, self).__init__() + self.config = 
config + self.init_val = self.config.init_val + self.register_parameter('variance', nn.Parameter(torch.tensor(self.config.init_val))) + self.modulate = self.config.get('modulate', False) + if self.modulate: + self.mod_start_steps = self.config.mod_start_steps + self.reach_max_steps = self.config.reach_max_steps + self.max_inv_s = self.config.max_inv_s + + @property + def inv_s(self): + val = torch.exp(self.variance * 10.0) + if self.modulate and self.do_mod: + val = val.clamp_max(self.mod_val) + return val + + def forward(self, x): + return torch.ones([len(x), 1], device=self.variance.device) * self.inv_s + + def update_step(self, epoch, global_step): + if self.modulate: + self.do_mod = global_step > self.mod_start_steps + if not self.do_mod: + self.prev_inv_s = self.inv_s.item() + else: + self.mod_val = min((global_step / self.reach_max_steps) * (self.max_inv_s - self.prev_inv_s) + self.prev_inv_s, self.max_inv_s) + + +@models.register('neus') +class NeuSModel(BaseModel): + def setup(self): + self.geometry = models.make(self.config.geometry.name, self.config.geometry) + self.texture = models.make(self.config.texture.name, self.config.texture) + self.geometry.contraction_type = ContractionType.AABB + + if self.config.learned_background: + self.geometry_bg = models.make(self.config.geometry_bg.name, self.config.geometry_bg) + self.texture_bg = models.make(self.config.texture_bg.name, self.config.texture_bg) + self.geometry_bg.contraction_type = ContractionType.UN_BOUNDED_SPHERE + self.near_plane_bg, self.far_plane_bg = 0.1, 1e3 + self.cone_angle_bg = 10**(math.log10(self.far_plane_bg) / self.config.num_samples_per_ray_bg) - 1. + self.render_step_size_bg = 0.01 + + self.variance = VarianceNetwork(self.config.variance) + self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32)) + if self.config.grid_prune: + self.occupancy_grid = OccupancyGrid( + roi_aabb=self.scene_aabb, + resolution=128, + contraction_type=ContractionType.AABB + ) + if self.config.learned_background: + self.occupancy_grid_bg = OccupancyGrid( + roi_aabb=self.scene_aabb, + resolution=256, + contraction_type=ContractionType.UN_BOUNDED_SPHERE + ) + self.randomized = self.config.randomized + self.background_color = None + self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray + + def update_step(self, epoch, global_step): + update_module_step(self.geometry, epoch, global_step) + update_module_step(self.texture, epoch, global_step) + if self.config.learned_background: + update_module_step(self.geometry_bg, epoch, global_step) + update_module_step(self.texture_bg, epoch, global_step) + update_module_step(self.variance, epoch, global_step) + + cos_anneal_end = self.config.get('cos_anneal_end', 0) + self.cos_anneal_ratio = 1.0 if cos_anneal_end == 0 else min(1.0, global_step / cos_anneal_end) + + def occ_eval_fn(x): + sdf = self.geometry(x, with_grad=False, with_feature=False) + inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) + inv_s = inv_s.expand(sdf.shape[0], 1) + estimated_next_sdf = sdf[...,None] - self.render_step_size * 0.5 + estimated_prev_sdf = sdf[...,None] + self.render_step_size * 0.5 + prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) + next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) + p = prev_cdf - next_cdf + c = prev_cdf + alpha = ((p + 1e-5) / (c + 1e-5)).view(-1, 1).clip(0.0, 1.0) + return alpha + + def 
occ_eval_fn_bg(x): + density, _ = self.geometry_bg(x) + # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size_bg) based on taylor series + return density[...,None] * self.render_step_size_bg + + if self.training and self.config.grid_prune: + self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn, occ_thre=self.config.get('grid_prune_occ_thre', 0.01)) + if self.config.learned_background: + self.occupancy_grid_bg.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn_bg, occ_thre=self.config.get('grid_prune_occ_thre_bg', 0.01)) + + def isosurface(self): + mesh = self.geometry.isosurface() + return mesh + + def get_alpha(self, sdf, normal, dirs, dists): + inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter + inv_s = inv_s.expand(sdf.shape[0], 1) + + true_cos = (dirs * normal).sum(-1, keepdim=True) + + # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes + # the cos value "not dead" at the beginning training iterations, for better convergence. + iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio) + + F.relu(-true_cos) * self.cos_anneal_ratio) # always non-positive + + # Estimate signed distances at section points + estimated_next_sdf = sdf[...,None] + iter_cos * dists.reshape(-1, 1) * 0.5 + estimated_prev_sdf = sdf[...,None] - iter_cos * dists.reshape(-1, 1) * 0.5 + + prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) + next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) + + p = prev_cdf - next_cdf + c = prev_cdf + + alpha = ((p + 1e-5) / (c + 1e-5)).view(-1).clip(0.0, 1.0) + return alpha + + def forward_bg_(self, rays): + n_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + + def sigma_fn(t_starts, t_ends, ray_indices): + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends) / 2. + density, _ = self.geometry_bg(positions) + return density[...,None] + + _, t_max = ray_aabb_intersect(rays_o, rays_d, self.scene_aabb) + # if the ray intersects with the bounding box, start from the farther intersection point + # otherwise start from self.far_plane_bg + # note that in nerfacc t_max is set to 1e10 if there is no intersection + near_plane = torch.where(t_max > 1e9, self.near_plane_bg, t_max) + with torch.no_grad(): + ray_indices, t_starts, t_ends = ray_marching( + rays_o, rays_d, + scene_aabb=None, + grid=self.occupancy_grid_bg if self.config.grid_prune else None, + sigma_fn=sigma_fn, + near_plane=near_plane, far_plane=self.far_plane_bg, + render_step_size=self.render_step_size_bg, + stratified=self.randomized, + cone_angle=self.cone_angle_bg, + alpha_thre=0.0 + ) + + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + midpoints = (t_starts + t_ends) / 2. 
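+ # The background branch renders a plain NeRF in the contracted unbounded-sphere space;
+ # its comp_rgb is later blended behind the foreground as
+ #   rgb_full = rgb_fg + rgb_bg * (1 - opacity_fg)
+ # (see the *_full outputs assembled at the end of forward_).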
+ positions = t_origins + t_dirs * midpoints + intervals = t_ends - t_starts + + density, feature = self.geometry_bg(positions) + rgb = self.texture_bg(feature, t_dirs) + + weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays) + opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays) + depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays) + comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays) + comp_rgb = comp_rgb + self.background_color * (1.0 - opacity) + + out = { + 'comp_rgb': comp_rgb, + 'opacity': opacity, + 'depth': depth, + 'rays_valid': opacity > 0, + 'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device) + } + + if self.training: + out.update({ + 'weights': weights.view(-1), + 'points': midpoints.view(-1), + 'intervals': intervals.view(-1), + 'ray_indices': ray_indices.view(-1) + }) + + return out + + def forward_(self, rays): + n_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + + with torch.no_grad(): + ray_indices, t_starts, t_ends = ray_marching( + rays_o, rays_d, + scene_aabb=self.scene_aabb, + grid=self.occupancy_grid if self.config.grid_prune else None, + alpha_fn=None, + near_plane=None, far_plane=None, + render_step_size=self.render_step_size, + stratified=self.randomized, + cone_angle=0.0, + alpha_thre=0.0 + ) + + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + midpoints = (t_starts + t_ends) / 2. + positions = t_origins + t_dirs * midpoints + dists = t_ends - t_starts + + if self.config.geometry.grad_type == 'finite_difference': + sdf, sdf_grad, feature, sdf_laplace = self.geometry(positions, with_grad=True, with_feature=True, with_laplace=True) + else: + sdf, sdf_grad, feature = self.geometry(positions, with_grad=True, with_feature=True) + + normal = F.normalize(sdf_grad, p=2, dim=-1) + alpha = self.get_alpha(sdf, normal, t_dirs, dists)[...,None] + rgb = self.texture(feature, t_dirs, normal) + + weights = render_weight_from_alpha(alpha, ray_indices=ray_indices, n_rays=n_rays) + opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays) + depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays) + comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays) + + comp_normal = accumulate_along_rays(weights, ray_indices, values=normal, n_rays=n_rays) + comp_normal = F.normalize(comp_normal, p=2, dim=-1) + + pts_random = torch.rand([1024*2, 3]).to(sdf.dtype).to(sdf.device) * 2 - 1 # normalized to (-1, 1) + + if self.config.geometry.grad_type == 'finite_difference': + random_sdf, random_sdf_grad, _ = self.geometry(pts_random, with_grad=True, with_feature=False, with_laplace=True) + _, normal_perturb, _ = self.geometry( + pts_random + torch.randn_like(pts_random) * 1e-2, + with_grad=True, with_feature=False, with_laplace=True + ) + else: + random_sdf, random_sdf_grad = self.geometry(pts_random, with_grad=True, with_feature=False) + _, normal_perturb = self.geometry(positions + torch.randn_like(positions) * 1e-2, + with_grad=True, with_feature=False,) + + # pdb.set_trace() + out = { + 'comp_rgb': comp_rgb, + 'comp_normal': comp_normal, + 'opacity': opacity, + 'depth': depth, + 'rays_valid': opacity > 0, + 'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device) + } + + if self.training: + out.update({ + 
'sdf_samples': sdf, + 'sdf_grad_samples': sdf_grad, + 'random_sdf': random_sdf, + 'random_sdf_grad': random_sdf_grad, + 'normal_perturb' : normal_perturb, + 'weights': weights.view(-1), + 'points': midpoints.view(-1), + 'intervals': dists.view(-1), + 'ray_indices': ray_indices.view(-1) + }) + if self.config.geometry.grad_type == 'finite_difference': + out.update({ + 'sdf_laplace_samples': sdf_laplace + }) + + if self.config.learned_background: + out_bg = self.forward_bg_(rays) + else: + out_bg = { + 'comp_rgb': self.background_color[None,:].expand(*comp_rgb.shape), + 'num_samples': torch.zeros_like(out['num_samples']), + 'rays_valid': torch.zeros_like(out['rays_valid']) + } + + out_full = { + 'comp_rgb': out['comp_rgb'] + out_bg['comp_rgb'] * (1.0 - out['opacity']), + 'num_samples': out['num_samples'] + out_bg['num_samples'], + 'rays_valid': out['rays_valid'] | out_bg['rays_valid'] + } + + return { + **out, + **{k + '_bg': v for k, v in out_bg.items()}, + **{k + '_full': v for k, v in out_full.items()} + } + + def forward(self, rays): + if self.training: + out = self.forward_(rays) + else: + out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays) + return { + **out, + 'inv_s': self.variance.inv_s + } + + def train(self, mode=True): + self.randomized = mode and self.config.randomized + return super().train(mode=mode) + + def eval(self): + self.randomized = False + return super().eval() + + def regularizations(self, out): + losses = {} + losses.update(self.geometry.regularizations(out)) + losses.update(self.texture.regularizations(out)) + return losses + + @torch.no_grad() + def export(self, export_config): + mesh = self.isosurface() + if export_config.export_vertex_color: + _, sdf_grad, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank), with_grad=True, with_feature=True) + normal = F.normalize(sdf_grad, p=2, dim=-1) + rgb = self.texture(feature, -normal, normal) # set the viewing directions to the normal to get "albedo" + mesh['v_rgb'] = rgb.cpu() + return mesh diff --git a/instant-nsr-pl/models/ray_utils.py b/instant-nsr-pl/models/ray_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ca1866fa43aedb83e111233af1c5d0e37dbedf75 --- /dev/null +++ b/instant-nsr-pl/models/ray_utils.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +def cast_rays(ori, dir, z_vals): + return ori[..., None, :] + z_vals[..., None] * dir[..., None, :] + + +def get_ray_directions(W, H, fx, fy, cx, cy, use_pixel_centers=True): + pixel_center = 0.5 if use_pixel_centers else 0 + i, j = np.meshgrid( + np.arange(W, dtype=np.float32) + pixel_center, + np.arange(H, dtype=np.float32) + pixel_center, + indexing='xy' + ) + i, j = torch.from_numpy(i), torch.from_numpy(j) + + # directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1) # (H, W, 3) + # opencv system + directions = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], -1) # (H, W, 3) + + return directions + + +def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True): + pixel_center = 0.5 if use_pixel_centers else 0 + i, j = np.meshgrid( + np.arange(W, dtype=np.float32) + pixel_center, + np.arange(H, dtype=np.float32) + pixel_center, + indexing='xy' + ) + i, j = torch.from_numpy(i), torch.from_numpy(j) + + origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2, torch.zeros_like(i)], dim=-1) # W, H, 3 + directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3 + + return origins, 
directions + + +def get_rays(directions, c2w, keepdim=False): + # Rotate ray directions from camera coordinate to the world coordinate + # rays_d = directions @ c2w[:, :3].T # (H, W, 3) # slow? + assert directions.shape[-1] == 3 + + if directions.ndim == 2: # (N_rays, 3) + assert c2w.ndim == 3 # (N_rays, 4, 4) / (1, 4, 4) + rays_d = (directions[:,None,:] * c2w[:,:3,:3]).sum(-1) # (N_rays, 3) + rays_o = c2w[:,:,3].expand(rays_d.shape) + elif directions.ndim == 3: # (H, W, 3) + if c2w.ndim == 2: # (4, 4) + rays_d = (directions[:,:,None,:] * c2w[None,None,:3,:3]).sum(-1) # (H, W, 3) + rays_o = c2w[None,None,:,3].expand(rays_d.shape) + elif c2w.ndim == 3: # (B, 4, 4) + rays_d = (directions[None,:,:,None,:] * c2w[:,None,None,:3,:3]).sum(-1) # (B, H, W, 3) + rays_o = c2w[:,None,None,:,3].expand(rays_d.shape) + + if not keepdim: + rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) + + return rays_o, rays_d + + +# rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3].cuda(), rays_v[:, :, :, None].cuda()).squeeze() # W, H, 3 + +# rays_o = torch.matmul(self.pose_all[img_idx, None, None, :3, :3].cuda(), q[:, :, :, None].cuda()).squeeze() # W, H, 3 +# rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape).cuda() + rays_o # W, H, 3 + +def get_ortho_rays(origins, directions, c2w, keepdim=False): + # Rotate ray directions from camera coordinate to the world coordinate + # rays_d = directions @ c2w[:, :3].T # (H, W, 3) # slow? + assert directions.shape[-1] == 3 + assert origins.shape[-1] == 3 + + if directions.ndim == 2: # (N_rays, 3) + assert c2w.ndim == 3 # (N_rays, 4, 4) / (1, 4, 4) + rays_d = torch.matmul(c2w[:, :3, :3], directions[:, :, None]).squeeze() # (N_rays, 3) + rays_o = torch.matmul(c2w[:, :3, :3], origins[:, :, None]).squeeze() # (N_rays, 3) + rays_o = c2w[:,:3,3].expand(rays_d.shape) + rays_o + elif directions.ndim == 3: # (H, W, 3) + if c2w.ndim == 2: # (4, 4) + rays_d = torch.matmul(c2w[None, None, :3, :3], directions[:, :, :, None]).squeeze() # (H, W, 3) + rays_o = torch.matmul(c2w[None, None, :3, :3], origins[:, :, :, None]).squeeze() # (H, W, 3) + rays_o = c2w[None, None,:3,3].expand(rays_d.shape) + rays_o + elif c2w.ndim == 3: # (B, 4, 4) + rays_d = torch.matmul(c2w[:,None, None, :3, :3], directions[None, :, :, :, None]).squeeze() # # (B, H, W, 3) + rays_o = torch.matmul(c2w[:,None, None, :3, :3], origins[None, :, :, :, None]).squeeze() # # (B, H, W, 3) + rays_o = c2w[:,None, None, :3,3].expand(rays_d.shape) + rays_o + + if not keepdim: + rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) + + return rays_o, rays_d diff --git a/instant-nsr-pl/models/texture.py b/instant-nsr-pl/models/texture.py new file mode 100644 index 0000000000000000000000000000000000000000..4a83c9775c89d812cf6009155a414771c5462ebf --- /dev/null +++ b/instant-nsr-pl/models/texture.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn + +import models +from models.utils import get_activation +from models.network_utils import get_encoding, get_mlp +from systems.utils import update_module_step + + +@models.register('volume-radiance') +class VolumeRadiance(nn.Module): + def __init__(self, config): + super(VolumeRadiance, self).__init__() + self.config = config + self.with_viewdir = False #self.config.get('wo_viewdir', False) + self.n_dir_dims = self.config.get('n_dir_dims', 3) + self.n_output_dims = 3 + + if self.with_viewdir: + encoding = get_encoding(self.n_dir_dims, self.config.dir_encoding_config) + self.n_input_dims = self.config.input_feature_dim + 
encoding.n_output_dims + # self.network_base = get_mlp(self.config.input_feature_dim, self.n_output_dims, self.config.mlp_network_config) + else: + encoding = None + self.n_input_dims = self.config.input_feature_dim + + network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) + self.encoding = encoding + self.network = network + + def forward(self, features, dirs, *args): + + # features = features.detach() + if self.with_viewdir: + dirs = (dirs + 1.) / 2. # (-1, 1) => (0, 1) + dirs_embd = self.encoding(dirs.view(-1, self.n_dir_dims)) + network_inp = torch.cat([features.view(-1, features.shape[-1]), dirs_embd] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) + # network_inp_base = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) + color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() + # color_base = self.network_base(network_inp_base).view(*features.shape[:-1], self.n_output_dims).float() + # color = color + color_base + else: + network_inp = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) + color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() + + if 'color_activation' in self.config: + color = get_activation(self.config.color_activation)(color) + return color + + def update_step(self, epoch, global_step): + update_module_step(self.encoding, epoch, global_step) + + def regularizations(self, out): + return {} + + +@models.register('volume-color') +class VolumeColor(nn.Module): + def __init__(self, config): + super(VolumeColor, self).__init__() + self.config = config + self.n_output_dims = 3 + self.n_input_dims = self.config.input_feature_dim + network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) + self.network = network + + def forward(self, features, *args): + network_inp = features.view(-1, features.shape[-1]) + color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() + if 'color_activation' in self.config: + color = get_activation(self.config.color_activation)(color) + return color + + def regularizations(self, out): + return {} diff --git a/instant-nsr-pl/models/utils.py b/instant-nsr-pl/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5c3cf19dd3e8f277783db68f1435c8f9755e96 --- /dev/null +++ b/instant-nsr-pl/models/utils.py @@ -0,0 +1,119 @@ +import gc +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +import tinycudann as tcnn + + +def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs): + B = None + for arg in args: + if isinstance(arg, torch.Tensor): + B = arg.shape[0] + break + out = defaultdict(list) + out_type = None + for i in range(0, B, chunk_size): + out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs) + if out_chunk is None: + continue + out_type = type(out_chunk) + if isinstance(out_chunk, torch.Tensor): + out_chunk = {0: out_chunk} + elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): + chunk_length = len(out_chunk) + out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} + elif isinstance(out_chunk, dict): + pass + else: + print(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get 
{type(out_chunk)}.') + exit(1) + for k, v in out_chunk.items(): + v = v if torch.is_grad_enabled() else v.detach() + v = v.cpu() if move_to_cpu else v + out[k].append(v) + + if out_type is None: + return + + out = {k: torch.cat(v, dim=0) for k, v in out.items()} + if out_type is torch.Tensor: + return out[0] + elif out_type in [tuple, list]: + return out_type([out[i] for i in range(chunk_length)]) + elif out_type is dict: + return out + + +class _TruncExp(Function): # pylint: disable=abstract-method + # Implementation from torch-ngp: + # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, x): # pylint: disable=arguments-differ + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): # pylint: disable=arguments-differ + x = ctx.saved_tensors[0] + return g * torch.exp(torch.clamp(x, max=15)) + +trunc_exp = _TruncExp.apply + + +def get_activation(name): + if name is None: + return lambda x: x + name = name.lower() + if name == 'none': + return lambda x: x + elif name.startswith('scale'): + scale_factor = float(name[5:]) + return lambda x: x.clamp(0., scale_factor) / scale_factor + elif name.startswith('clamp'): + clamp_max = float(name[5:]) + return lambda x: x.clamp(0., clamp_max) + elif name.startswith('mul'): + mul_factor = float(name[3:]) + return lambda x: x * mul_factor + elif name == 'lin2srgb': + return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.) + elif name == 'trunc_exp': + return trunc_exp + elif name.startswith('+') or name.startswith('-'): + return lambda x: x + float(name) + elif name == 'sigmoid': + return lambda x: torch.sigmoid(x) + elif name == 'tanh': + return lambda x: torch.tanh(x) + else: + return getattr(F, name) + + +def dot(x, y): + return torch.sum(x*y, -1, keepdim=True) + + +def reflect(x, n): + return 2 * dot(x, n) * n - x + + +def scale_anything(dat, inp_scale, tgt_scale): + if inp_scale is None: + inp_scale = [dat.min(), dat.max()] + dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) + dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] + return dat + + +def cleanup(): + gc.collect() + torch.cuda.empty_cache() + tcnn.free_temporary_memory() diff --git a/instant-nsr-pl/requirements.txt b/instant-nsr-pl/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3ab886c0d76248bfd1035579b460292458ef25e6 --- /dev/null +++ b/instant-nsr-pl/requirements.txt @@ -0,0 +1,12 @@ +pytorch-lightning<2 +omegaconf==2.2.3 +nerfacc==0.3.3 +matplotlib +opencv-python +imageio +imageio-ffmpeg +scipy +PyMCubes +pyransac3d +torch_efficient_distloss +tensorboard diff --git a/instant-nsr-pl/run.sh b/instant-nsr-pl/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..617143ccecba268e77b2aeb48cb3ec266d098c40 --- /dev/null +++ b/instant-nsr-pl/run.sh @@ -0,0 +1 @@ +python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=$1 dataset.scene=$2 \ No newline at end of file diff --git a/instant-nsr-pl/scripts/imgs2poses.py b/instant-nsr-pl/scripts/imgs2poses.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b6e0b19c7192fceee0518b2cde691bfabd4ff4 --- /dev/null +++ b/instant-nsr-pl/scripts/imgs2poses.py @@ -0,0 +1,85 @@ + +""" +This file is adapted from https://github.com/Fyusion/LLFF. 
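+ It runs the COLMAP pipeline (feature_extractor -> <match_type> -> mapper) on
+ <scenedir>/images and writes the sparse reconstruction to <scenedir>/sparse/0
+ (cameras.bin, images.bin, points3D.bin), skipping COLMAP when those files already exist.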
+""" + +import os +import sys +import argparse +import subprocess + + +def run_colmap(basedir, match_type): + logfile_name = os.path.join(basedir, 'colmap_output.txt') + logfile = open(logfile_name, 'w') + + feature_extractor_args = [ + 'colmap', 'feature_extractor', + '--database_path', os.path.join(basedir, 'database.db'), + '--image_path', os.path.join(basedir, 'images'), + '--ImageReader.single_camera', '1' + ] + feat_output = ( subprocess.check_output(feature_extractor_args, universal_newlines=True) ) + logfile.write(feat_output) + print('Features extracted') + + exhaustive_matcher_args = [ + 'colmap', match_type, + '--database_path', os.path.join(basedir, 'database.db'), + ] + + match_output = ( subprocess.check_output(exhaustive_matcher_args, universal_newlines=True) ) + logfile.write(match_output) + print('Features matched') + + p = os.path.join(basedir, 'sparse') + if not os.path.exists(p): + os.makedirs(p) + + mapper_args = [ + 'colmap', 'mapper', + '--database_path', os.path.join(basedir, 'database.db'), + '--image_path', os.path.join(basedir, 'images'), + '--output_path', os.path.join(basedir, 'sparse'), # --export_path changed to --output_path in colmap 3.6 + '--Mapper.num_threads', '16', + '--Mapper.init_min_tri_angle', '4', + '--Mapper.multiple_models', '0', + '--Mapper.extract_colors', '0', + ] + + map_output = ( subprocess.check_output(mapper_args, universal_newlines=True) ) + logfile.write(map_output) + logfile.close() + print('Sparse map created') + + print( 'Finished running COLMAP, see {} for logs'.format(logfile_name) ) + + +def gen_poses(basedir, match_type): + files_needed = ['{}.bin'.format(f) for f in ['cameras', 'images', 'points3D']] + if os.path.exists(os.path.join(basedir, 'sparse/0')): + files_had = os.listdir(os.path.join(basedir, 'sparse/0')) + else: + files_had = [] + if not all([f in files_had for f in files_needed]): + print( 'Need to run COLMAP' ) + run_colmap(basedir, match_type) + else: + print('Don\'t need to run COLMAP') + + return True + + +if __name__=='__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--match_type', type=str, + default='exhaustive_matcher', help='type of matcher used. Valid options: \ + exhaustive_matcher sequential_matcher. Other matchers not supported at this time') + parser.add_argument('scenedir', type=str, + help='input scene directory') + args = parser.parse_args() + + if args.match_type != 'exhaustive_matcher' and args.match_type != 'sequential_matcher': + print('ERROR: matcher type ' + args.match_type + ' is not valid. Aborting') + sys.exit() + gen_poses(args.scenedir, args.match_type) diff --git a/instant-nsr-pl/systems/__init__.py b/instant-nsr-pl/systems/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df114d15f57bfcceb5b626be4c97a8d4c442cee8 --- /dev/null +++ b/instant-nsr-pl/systems/__init__.py @@ -0,0 +1,19 @@ +systems = {} + + +def register(name): + def decorator(cls): + systems[name] = cls + return cls + return decorator + + +def make(name, config, load_from_checkpoint=None): + if load_from_checkpoint is None: + system = systems[name](config) + else: + system = systems[name].load_from_checkpoint(load_from_checkpoint, strict=False, config=config) + return system + + +from . 
import neus, neus_ortho, neus_pinhole diff --git a/instant-nsr-pl/systems/base.py b/instant-nsr-pl/systems/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bcdbdc76d810548f85ebbaf64870a33f5ddaf1 --- /dev/null +++ b/instant-nsr-pl/systems/base.py @@ -0,0 +1,128 @@ +import pytorch_lightning as pl + +import models +from systems.utils import parse_optimizer, parse_scheduler, update_module_step +from utils.mixins import SaverMixin +from utils.misc import config_to_primitive, get_rank + + +class BaseSystem(pl.LightningModule, SaverMixin): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. rank_zero_info: use the logging module + """ + def __init__(self, config): + super().__init__() + self.config = config + self.rank = get_rank() + self.prepare() + self.model = models.make(self.config.model.name, self.config.model) + + def prepare(self): + pass + + def forward(self, batch): + raise NotImplementedError + + def C(self, value): + if isinstance(value, int) or isinstance(value, float): + pass + else: + value = config_to_primitive(value) + if not isinstance(value, list): + raise TypeError('Scalar specification only supports list, got', type(value)) + if len(value) == 3: + value = [0] + value + assert len(value) == 4 + start_step, start_value, end_value, end_step = value + if isinstance(end_step, int): + current_step = self.global_step + value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0) + elif isinstance(end_step, float): + current_step = self.current_epoch + value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0) + return value + + def preprocess_data(self, batch, stage): + pass + + """ + Implementing on_after_batch_transfer of DataModule does the same. + But on_after_batch_transfer does not support DP. + """ + def on_train_batch_start(self, batch, batch_idx, unused=0): + self.dataset = self.trainer.datamodule.train_dataloader().dataset + self.preprocess_data(batch, 'train') + update_module_step(self.model, self.current_epoch, self.global_step) + + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): + self.dataset = self.trainer.datamodule.val_dataloader().dataset + self.preprocess_data(batch, 'validation') + update_module_step(self.model, self.current_epoch, self.global_step) + + def on_test_batch_start(self, batch, batch_idx, dataloader_idx): + self.dataset = self.trainer.datamodule.test_dataloader().dataset + self.preprocess_data(batch, 'test') + update_module_step(self.model, self.current_epoch, self.global_step) + + def on_predict_batch_start(self, batch, batch_idx, dataloader_idx): + self.dataset = self.trainer.datamodule.predict_dataloader().dataset + self.preprocess_data(batch, 'predict') + update_module_step(self.model, self.current_epoch, self.global_step) + + def training_step(self, batch, batch_idx): + raise NotImplementedError + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + raise NotImplementedError + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + """ + Gather metrics from all devices, compute mean. + Purge repeated results using data index. 
+ """ + raise NotImplementedError + + def test_step(self, batch, batch_idx): + raise NotImplementedError + + def test_epoch_end(self, out): + """ + Gather metrics from all devices, compute mean. + Purge repeated results using data index. + """ + raise NotImplementedError + + def export(self): + raise NotImplementedError + + def configure_optimizers(self): + optim = parse_optimizer(self.config.system.optimizer, self.model) + ret = { + 'optimizer': optim, + } + if 'scheduler' in self.config.system: + ret.update({ + 'lr_scheduler': parse_scheduler(self.config.system.scheduler, optim), + }) + return ret + diff --git a/instant-nsr-pl/systems/criterions.py b/instant-nsr-pl/systems/criterions.py new file mode 100644 index 0000000000000000000000000000000000000000..b101032ec7bc8d9943dd5df47557c4b6d3aa465b --- /dev/null +++ b/instant-nsr-pl/systems/criterions.py @@ -0,0 +1,164 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class WeightedLoss(nn.Module): + @property + def func(self): + raise NotImplementedError + + def forward(self, inputs, targets, weight=None, reduction='mean'): + assert reduction in ['none', 'sum', 'mean', 'valid_mean'] + loss = self.func(inputs, targets, reduction='none') + if weight is not None: + while weight.ndim < inputs.ndim: + weight = weight[..., None] + loss *= weight.float() + if reduction == 'none': + return loss + elif reduction == 'sum': + return loss.sum() + elif reduction == 'mean': + return loss.mean() + elif reduction == 'valid_mean': + return loss.sum() / weight.float().sum() + + +class MSELoss(WeightedLoss): + @property + def func(self): + return F.mse_loss + + +class L1Loss(WeightedLoss): + @property + def func(self): + return F.l1_loss + + +class PSNR(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inputs, targets, valid_mask=None, reduction='mean'): + assert reduction in ['mean', 'none'] + value = (inputs - targets)**2 + if valid_mask is not None: + value = value[valid_mask] + if reduction == 'mean': + return -10 * torch.log10(torch.mean(value)) + elif reduction == 'none': + return -10 * torch.log10(torch.mean(value, dim=tuple(range(value.ndim)[1:]))) + + +class SSIM(): + def __init__(self, data_range=(0, 1), kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True): + self.kernel_size = kernel_size + self.sigma = sigma + self.gaussian = gaussian + + if any(x % 2 == 0 or x <= 0 for x in self.kernel_size): + raise ValueError(f"Expected kernel_size to have odd positive number. Got {kernel_size}.") + if any(y <= 0 for y in self.sigma): + raise ValueError(f"Expected sigma to have positive number. 
Got {sigma}.") + + data_scale = data_range[1] - data_range[0] + self.c1 = (k1 * data_scale)**2 + self.c2 = (k2 * data_scale)**2 + self.pad_h = (self.kernel_size[0] - 1) // 2 + self.pad_w = (self.kernel_size[1] - 1) // 2 + self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma) + + def _uniform(self, kernel_size): + max, min = 2.5, -2.5 + ksize_half = (kernel_size - 1) * 0.5 + kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + for i, j in enumerate(kernel): + if min <= j <= max: + kernel[i] = 1 / (max - min) + else: + kernel[i] = 0 + + return kernel.unsqueeze(dim=0) # (1, kernel_size) + + def _gaussian(self, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + gauss = torch.exp(-0.5 * (kernel / sigma).pow(2)) + return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) + + def _gaussian_or_uniform_kernel(self, kernel_size, sigma): + if self.gaussian: + kernel_x = self._gaussian(kernel_size[0], sigma[0]) + kernel_y = self._gaussian(kernel_size[1], sigma[1]) + else: + kernel_x = self._uniform(kernel_size[0]) + kernel_y = self._uniform(kernel_size[1]) + + return torch.matmul(kernel_x.t(), kernel_y) # (kernel_size, 1) * (1, kernel_size) + + def __call__(self, output, target, reduction='mean'): + if output.dtype != target.dtype: + raise TypeError( + f"Expected output and target to have the same data type. Got output: {output.dtype} and y: {target.dtype}." + ) + + if output.shape != target.shape: + raise ValueError( + f"Expected output and target to have the same shape. Got output: {output.shape} and y: {target.shape}." + ) + + if len(output.shape) != 4 or len(target.shape) != 4: + raise ValueError( + f"Expected output and target to have BxCxHxW shape. Got output: {output.shape} and y: {target.shape}." + ) + + assert reduction in ['mean', 'sum', 'none'] + + channel = output.size(1) + if len(self._kernel.shape) < 4: + self._kernel = self._kernel.expand(channel, 1, -1, -1) + + output = F.pad(output, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") + target = F.pad(target, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") + + input_list = torch.cat([output, target, output * output, target * target, output * target]) + outputs = F.conv2d(input_list, self._kernel, groups=channel) + + output_list = [outputs[x * output.size(0) : (x + 1) * output.size(0)] for x in range(len(outputs))] + + mu_pred_sq = output_list[0].pow(2) + mu_target_sq = output_list[1].pow(2) + mu_pred_target = output_list[0] * output_list[1] + + sigma_pred_sq = output_list[2] - mu_pred_sq + sigma_target_sq = output_list[3] - mu_target_sq + sigma_pred_target = output_list[4] - mu_pred_target + + a1 = 2 * mu_pred_target + self.c1 + a2 = 2 * sigma_pred_target + self.c2 + b1 = mu_pred_sq + mu_target_sq + self.c1 + b2 = sigma_pred_sq + sigma_target_sq + self.c2 + + ssim_idx = (a1 * a2) / (b1 * b2) + _ssim = torch.mean(ssim_idx, (1, 2, 3)) + + if reduction == 'none': + return _ssim + elif reduction == 'sum': + return _ssim.sum() + elif reduction == 'mean': + return _ssim.mean() + + +def binary_cross_entropy(input, target, reduction='mean'): + """ + F.binary_cross_entropy is not numerically stable in mixed-precision training. 
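+ PyTorch refuses to run F.binary_cross_entropy under autocast because fp16 log() of
+ probabilities at or near 0/1 blows up, so the loss is spelled out explicitly here;
+ inputs are assumed to lie strictly in (0, 1).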
+ """ + loss = -(target * torch.log(input) + (1 - target) * torch.log(1 - input)) + + if reduction == 'mean': + return loss.mean() + elif reduction == 'none': + return loss diff --git a/instant-nsr-pl/systems/nerf.py b/instant-nsr-pl/systems/nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..c5fc821f430aeee62880b7240e75a185ca9b15f2 --- /dev/null +++ b/instant-nsr-pl/systems/nerf.py @@ -0,0 +1,218 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_efficient_distloss import flatten_eff_distloss + +import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug + +import models +from models.ray_utils import get_rays +import systems +from systems.base import BaseSystem +from systems.criterions import PSNR + + +@systems.register('nerf-system') +class NeRFSystem(BaseSystem): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. rank_zero_info: use the logging module + """ + def prepare(self): + self.criterions = { + 'psnr': PSNR() + } + self.train_num_samples = self.config.model.train_num_rays * self.config.model.num_samples_per_ray + self.train_num_rays = self.config.model.train_num_rays + + def forward(self, batch): + return self.model(batch['rays']) + + def preprocess_data(self, batch, stage): + if 'index' in batch: # validation / testing + index = batch['index'] + else: + if self.config.model.batch_image_sampling: + index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) + else: + index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) + if stage in ['train']: + c2w = self.dataset.all_c2w[index] + x = torch.randint( + 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + y = torch.randint( + 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions[y, x] + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index, y, x] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) + else: + c2w = self.dataset.all_c2w[index][0] + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index][0] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) + + rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) + + if stage in ['train']: + if self.config.model.background_color == 'white': + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'random': + self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) + else: + raise NotImplementedError + else: + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + + if self.dataset.apply_mask: + rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) + + batch.update({ + 'rays': 
rays, + 'rgb': rgb, + 'fg_mask': fg_mask + }) + + def training_step(self, batch, batch_idx): + out = self(batch) + + loss = 0. + + # update train_num_rays + if self.config.model.dynamic_ray_sampling: + train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples'].sum().item())) + self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) + + loss_rgb = F.smooth_l1_loss(out['comp_rgb'][out['rays_valid'][...,0]], batch['rgb'][out['rays_valid'][...,0]]) + self.log('train/loss_rgb', loss_rgb) + loss += loss_rgb * self.C(self.config.system.loss.lambda_rgb) + + # distortion loss proposed in MipNeRF360 + # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss, but still slows down training by ~30% + if self.C(self.config.system.loss.lambda_distortion) > 0: + loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) + self.log('train/loss_distortion', loss_distortion) + loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) + + losses_model_reg = self.model.regularizations(out) + for name, value in losses_model_reg.items(): + self.log(f'train/loss_{name}', value) + loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) + loss += loss_ + + for name, value in self.config.system.loss.items(): + if name.startswith('lambda'): + self.log(f'train_params/{name}', self.C(value)) + + self.log('train/num_rays', float(self.train_num_rays), prog_bar=True) + + return { + 'loss': loss + } + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) + + def test_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + 
{'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + def test_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) + + self.save_img_sequence( + f"it{self.global_step}-test", + f"it{self.global_step}-test", + '(\d+)\.png', + save_format='mp4', + fps=30 + ) + + self.export() + + def export(self): + mesh = self.model.export(self.config.export) + self.save_mesh( + f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", + **mesh + ) diff --git a/instant-nsr-pl/systems/neus.py b/instant-nsr-pl/systems/neus.py new file mode 100644 index 0000000000000000000000000000000000000000..ce273d0790a1fbea4d795b07e285c1318573a562 --- /dev/null +++ b/instant-nsr-pl/systems/neus.py @@ -0,0 +1,265 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_efficient_distloss import flatten_eff_distloss + +import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug + +import models +from models.utils import cleanup +from models.ray_utils import get_rays +import systems +from systems.base import BaseSystem +from systems.criterions import PSNR, binary_cross_entropy + + +@systems.register('neus-system') +class NeuSSystem(BaseSystem): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. 
rank_zero_info: use the logging module + """ + def prepare(self): + self.criterions = { + 'psnr': PSNR() + } + self.train_num_samples = self.config.model.train_num_rays * (self.config.model.num_samples_per_ray + self.config.model.get('num_samples_per_ray_bg', 0)) + self.train_num_rays = self.config.model.train_num_rays + + def forward(self, batch): + return self.model(batch['rays']) + + def preprocess_data(self, batch, stage): + if 'index' in batch: # validation / testing + index = batch['index'] + else: + if self.config.model.batch_image_sampling: + index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) + else: + index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) + if stage in ['train']: + c2w = self.dataset.all_c2w[index] + x = torch.randint( + 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + y = torch.randint( + 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions[y, x] + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index, y, x] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) + else: + c2w = self.dataset.all_c2w[index][0] + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index][0] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) + + rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) + + if stage in ['train']: + if self.config.model.background_color == 'white': + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'random': + self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) + else: + raise NotImplementedError + else: + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + + if self.dataset.apply_mask: + rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) + + batch.update({ + 'rays': rays, + 'rgb': rgb, + 'fg_mask': fg_mask + }) + + def training_step(self, batch, batch_idx): + out = self(batch) + + loss = 0. 
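+        # The terms accumulated below are each weighted by their lambda_* entry in config.system.loss:
+        #   - photometric MSE and L1 on rays the renderer marks valid (rays_valid_full)
+        #   - the eikonal regularizer mean((|grad(sdf)| - 1)^2), keeping the field a signed distance
+        #   - BCE of ray opacity against the foreground mask, and against itself (pushes opacity toward 0/1)
+        #   - a sparsity term mean(exp(-sparsity_scale * |sdf|)) discouraging spurious near-zero SDF values
+        #   - optional curvature (|SDF Laplacian|) and MipNeRF360-style distortion regularizers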
+ + # update train_num_rays + if self.config.model.dynamic_ray_sampling: + train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples_full'].sum().item())) + self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) + + loss_rgb_mse = F.mse_loss(out['comp_rgb_full'][out['rays_valid_full'][...,0]], batch['rgb'][out['rays_valid_full'][...,0]]) + self.log('train/loss_rgb_mse', loss_rgb_mse) + loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse) + + loss_rgb_l1 = F.l1_loss(out['comp_rgb_full'][out['rays_valid_full'][...,0]], batch['rgb'][out['rays_valid_full'][...,0]]) + self.log('train/loss_rgb', loss_rgb_l1) + loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1) + + loss_eikonal = ((torch.linalg.norm(out['sdf_grad_samples'], ord=2, dim=-1) - 1.)**2).mean() + self.log('train/loss_eikonal', loss_eikonal) + loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal) + + opacity = torch.clamp(out['opacity'].squeeze(-1), 1.e-3, 1.-1.e-3) + loss_mask = binary_cross_entropy(opacity, batch['fg_mask'].float()) + self.log('train/loss_mask', loss_mask) + loss += loss_mask * (self.C(self.config.system.loss.lambda_mask) if self.dataset.has_mask else 0.0) + + loss_opaque = binary_cross_entropy(opacity, opacity) + self.log('train/loss_opaque', loss_opaque) + loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque) + + loss_sparsity = torch.exp(-self.config.system.loss.sparsity_scale * out['sdf_samples'].abs()).mean() + self.log('train/loss_sparsity', loss_sparsity) + loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity) + + if self.C(self.config.system.loss.lambda_curvature) > 0: + assert 'sdf_laplace_samples' in out, "Need geometry.grad_type='finite_difference' to get SDF Laplace samples" + loss_curvature = out['sdf_laplace_samples'].abs().mean() + self.log('train/loss_curvature', loss_curvature) + loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature) + + # distortion loss proposed in MipNeRF360 + # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss + if self.C(self.config.system.loss.lambda_distortion) > 0: + loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) + self.log('train/loss_distortion', loss_distortion) + loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) + + if self.config.model.learned_background and self.C(self.config.system.loss.lambda_distortion_bg) > 0: + loss_distortion_bg = flatten_eff_distloss(out['weights_bg'], out['points_bg'], out['intervals_bg'], out['ray_indices_bg']) + self.log('train/loss_distortion_bg', loss_distortion_bg) + loss += loss_distortion_bg * self.C(self.config.system.loss.lambda_distortion_bg) + + losses_model_reg = self.model.regularizations(out) + for name, value in losses_model_reg.items(): + self.log(f'train/loss_{name}', value) + loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) + loss += loss_ + + self.log('train/inv_s', out['inv_s'], prog_bar=True) + + for name, value in self.config.system.loss.items(): + if name.startswith('lambda'): + self.log(f'train_params/{name}', self.C(value)) + + self.log('train/num_rays', float(self.train_num_rays), prog_bar=True) + + return { + 'loss': loss + } + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def 
training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + ] + ([ + {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + ] if self.config.model.learned_background else []) + [ + {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) + + def test_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + ] + ([ + {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + ] if self.config.model.learned_background else []) + [ + {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + def test_epoch_end(self, out): + """ + Synchronize devices. + Generate image sequence using test outputs. 
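+        Per-image PSNRs gathered from every rank are de-duplicated by image index before
+        being averaged into the logged test/psnr.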
+ """ + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) + + self.save_img_sequence( + f"it{self.global_step}-test", + f"it{self.global_step}-test", + '(\d+)\.png', + save_format='mp4', + fps=30 + ) + + self.export() + + def export(self): + mesh = self.model.export(self.config.export) + self.save_mesh( + f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", + **mesh + ) diff --git a/instant-nsr-pl/systems/neus_ortho.py b/instant-nsr-pl/systems/neus_ortho.py new file mode 100644 index 0000000000000000000000000000000000000000..803b2a84564e491e16883ee0177979c4280e8b3d --- /dev/null +++ b/instant-nsr-pl/systems/neus_ortho.py @@ -0,0 +1,358 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_efficient_distloss import flatten_eff_distloss + +import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug + +import models +from models.utils import cleanup +from models.ray_utils import get_ortho_rays +import systems +from systems.base import BaseSystem +from systems.criterions import PSNR, binary_cross_entropy + +import pdb + +def ranking_loss(error, penalize_ratio=0.7, extra_weights=None , type='mean'): + error, indices = torch.sort(error) + # only sum relatively small errors + s_error = torch.index_select(error, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + if extra_weights is not None: + weights = torch.index_select(extra_weights, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + s_error = s_error * weights + + if type == 'mean': + return torch.mean(s_error) + elif type == 'sum': + return torch.sum(s_error) + +@systems.register('ortho-neus-system') +class OrthoNeuSSystem(BaseSystem): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. 
rank_zero_info: use the logging module + """ + def prepare(self): + self.criterions = { + 'psnr': PSNR() + } + self.train_num_samples = self.config.model.train_num_rays * (self.config.model.num_samples_per_ray + self.config.model.get('num_samples_per_ray_bg', 0)) + self.train_num_rays = self.config.model.train_num_rays + self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) + + def forward(self, batch): + return self.model(batch['rays']) + + def preprocess_data(self, batch, stage): + if 'index' in batch: # validation / testing + index = batch['index'] + else: + if self.config.model.batch_image_sampling: + index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) + else: + index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) + if stage in ['train']: + c2w = self.dataset.all_c2w[index] + x = torch.randint( + 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + y = torch.randint( + 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions[y, x] + origins = self.dataset.origins[y, x] + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index, y, x] + origins = self.dataset.origins[index, y, x] + rays_o, rays_d = get_ortho_rays(origins, directions, c2w) + rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + normal = self.dataset.all_normals_world[index, y, x].view(-1, self.dataset.all_normals_world.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index, y, x].view(-1).to(self.rank) + view_weights = self.dataset.view_weights[index, y, x].view(-1).to(self.rank) + else: + c2w = self.dataset.all_c2w[index][0] + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions + origins = self.dataset.origins + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index][0] + origins = self.dataset.origins[index][0] + rays_o, rays_d = get_ortho_rays(origins, directions, c2w) + rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + normal = self.dataset.all_normals_world[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index].view(-1).to(self.rank) + view_weights = None + + cosines = self.cos(rays_d, normal) + rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) + + if stage in ['train']: + if self.config.model.background_color == 'white': + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'black': + self.model.background_color = torch.zeros((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'random': + self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) + else: + raise NotImplementedError + else: + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + + if self.dataset.apply_mask: + rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) + + batch.update({ + 'rays': rays, + 'rgb': rgb, + 
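+            # normal: world-space normals from the dataset views; together with cosines
+            # (view direction vs. normal, computed above) it feeds the normal loss in
+            # training_step, which emphasizes rays that face the surface more directly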
'normal': normal, + 'fg_mask': fg_mask, + 'rgb_mask': rgb_mask, + 'cosines': cosines, + 'view_weights': view_weights + }) + + def training_step(self, batch, batch_idx): + out = self(batch) + + cosines = batch['cosines'] + fg_mask = batch['fg_mask'] + rgb_mask = batch['rgb_mask'] + view_weights = batch['view_weights'] + + cosines[cosines > -0.1] = 0 + mask = ((fg_mask > 0) & (cosines < -0.1)) + rgb_mask = out['rays_valid_full'][...,0] & (rgb_mask > 0) + + grad_cosines = self.cos(batch['rays'][...,3:], out['comp_normal']).detach() + # grad_cosines = cosines + + loss = 0. + + # update train_num_rays + if self.config.model.dynamic_ray_sampling: + train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples_full'].sum().item())) + self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) + + erros_rgb_mse = F.mse_loss(out['comp_rgb_full'][rgb_mask], batch['rgb'][rgb_mask], reduction='none') + # erros_rgb_mse = erros_rgb_mse * torch.exp(grad_cosines.abs())[:, None][rgb_mask] / torch.exp(grad_cosines.abs()[rgb_mask]).sum() + # loss_rgb_mse = ranking_loss(erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type='sum') + loss_rgb_mse = ranking_loss(erros_rgb_mse.sum(dim=1), + penalize_ratio=self.config.system.loss.rgb_p_ratio, type='mean') + self.log('train/loss_rgb_mse', loss_rgb_mse, prog_bar=True, rank_zero_only=True) + loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse) + + loss_rgb_l1 = F.l1_loss(out['comp_rgb_full'][rgb_mask], batch['rgb'][rgb_mask], reduction='none') + loss_rgb_l1 = ranking_loss(loss_rgb_l1.sum(dim=1), + # extra_weights=view_weights[rgb_mask], + penalize_ratio=0.8) + self.log('train/loss_rgb', loss_rgb_l1) + loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1) + + normal_errors = 1 - F.cosine_similarity(out['comp_normal'], batch['normal'], dim=1) + # normal_errors = normal_errors * cosines.abs() / cosines.abs().sum() + if self.config.system.loss.geo_aware: + normal_errors = normal_errors * torch.exp(cosines.abs()) / torch.exp(cosines.abs()).sum() + loss_normal = ranking_loss(normal_errors[mask], + penalize_ratio=self.config.system.loss.normal_p_ratio, + extra_weights=view_weights[mask], + type='sum') + else: + loss_normal = ranking_loss(normal_errors[mask], + penalize_ratio=self.config.system.loss.normal_p_ratio, + extra_weights=view_weights[mask], + type='mean') + + self.log('train/loss_normal', loss_normal, prog_bar=True, rank_zero_only=True) + loss += loss_normal * self.C(self.config.system.loss.lambda_normal) + + loss_eikonal = ((torch.linalg.norm(out['sdf_grad_samples'], ord=2, dim=-1) - 1.)**2).mean() + self.log('train/loss_eikonal', loss_eikonal, prog_bar=True, rank_zero_only=True) + loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal) + + opacity = torch.clamp(out['opacity'].squeeze(-1), 1.e-3, 1.-1.e-3) + loss_mask = binary_cross_entropy(opacity, batch['fg_mask'].float(), reduction='none') + loss_mask = ranking_loss(loss_mask, + penalize_ratio=self.config.system.loss.mask_p_ratio, + extra_weights=view_weights) + self.log('train/loss_mask', loss_mask, prog_bar=True, rank_zero_only=True) + loss += loss_mask * (self.C(self.config.system.loss.lambda_mask) if self.dataset.has_mask else 0.0) + + loss_opaque = binary_cross_entropy(opacity, opacity) + self.log('train/loss_opaque', loss_opaque) + loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque) + + loss_sparsity = torch.exp(-self.config.system.loss.sparsity_scale * 
out['random_sdf'].abs()).mean() + self.log('train/loss_sparsity', loss_sparsity, prog_bar=True, rank_zero_only=True) + loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity) + + if self.C(self.config.system.loss.lambda_curvature) > 0: + assert 'sdf_laplace_samples' in out, "Need geometry.grad_type='finite_difference' to get SDF Laplace samples" + loss_curvature = out['sdf_laplace_samples'].abs().mean() + self.log('train/loss_curvature', loss_curvature) + loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature) + + # distortion loss proposed in MipNeRF360 + # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss + if self.C(self.config.system.loss.lambda_distortion) > 0: + loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) + self.log('train/loss_distortion', loss_distortion) + loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) + + if self.config.model.learned_background and self.C(self.config.system.loss.lambda_distortion_bg) > 0: + loss_distortion_bg = flatten_eff_distloss(out['weights_bg'], out['points_bg'], out['intervals_bg'], out['ray_indices_bg']) + self.log('train/loss_distortion_bg', loss_distortion_bg) + loss += loss_distortion_bg * self.C(self.config.system.loss.lambda_distortion_bg) + + if self.C(self.config.system.loss.lambda_3d_normal_smooth) > 0: + if "random_sdf_grad" not in out: + raise ValueError( + "random_sdf_grad is required for normal smooth loss, no normal is found in the output." + ) + if "normal_perturb" not in out: + raise ValueError( + "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output." + ) + normals_3d = out["random_sdf_grad"] + normals_perturb_3d = out["normal_perturb"] + loss_3d_normal_smooth = (normals_3d - normals_perturb_3d).abs().mean() + self.log('train/loss_3d_normal_smooth', loss_3d_normal_smooth, prog_bar=True ) + + loss += loss_3d_normal_smooth * self.C(self.config.system.loss.lambda_3d_normal_smooth) + + losses_model_reg = self.model.regularizations(out) + for name, value in losses_model_reg.items(): + self.log(f'train/loss_{name}', value) + loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) + loss += loss_ + + self.log('train/inv_s', out['inv_s'], prog_bar=True) + + for name, value in self.config.system.loss.items(): + if name.startswith('lambda'): + self.log(f'train_params/{name}', self.C(value)) + + self.log('train/num_rays', float(self.train_num_rays), prog_bar=True) + + return { + 'loss': loss + } + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + ] + ([ + {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + ] if self.config.model.learned_background else []) + [ + {'type': 'grayscale', 'img': out['depth'].view(H, 
W), 'kwargs': {}}, + {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) + self.export() + + # def test_step(self, batch, batch_idx): + # out = self(batch) + # psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + # W, H = self.dataset.img_wh + # self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ + # {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + # {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + # ] + ([ + # {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + # {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + # ] if self.config.model.learned_background else []) + [ + # {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + # {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + # ]) + # return { + # 'psnr': psnr, + # 'index': batch['index'] + # } + + def test_step(self, batch, batch_idx): + pass + + def test_epoch_end(self, out): + """ + Synchronize devices. + Generate image sequence using test outputs. 
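+        Note that the PSNR aggregation and video export below are commented out in this
+        ortho system, so only the mesh export is performed here.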
+ """ + # out = self.all_gather(out) + if self.trainer.is_global_zero: + # out_set = {} + # for step_out in out: + # # DP + # if step_out['index'].ndim == 1: + # out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # # DDP + # else: + # for oi, index in enumerate(step_out['index']): + # out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + # psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + # self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) + + # self.save_img_sequence( + # f"it{self.global_step}-test", + # f"it{self.global_step}-test", + # '(\d+)\.png', + # save_format='mp4', + # fps=30 + # ) + + self.export() + + def export(self): + mesh = self.model.export(self.config.export) + # pdb.set_trace() + self.save_mesh( + f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", + ortho_scale=self.config.export.ortho_scale, + **mesh + ) diff --git a/instant-nsr-pl/systems/neus_pinhole.py b/instant-nsr-pl/systems/neus_pinhole.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc224e0c9bc9e45f9c364be804e7353927ba1d1 --- /dev/null +++ b/instant-nsr-pl/systems/neus_pinhole.py @@ -0,0 +1,343 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_efficient_distloss import flatten_eff_distloss + +import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug + +import models +from models.utils import cleanup +from models.ray_utils import get_rays +import systems +from systems.base import BaseSystem +from systems.criterions import PSNR, binary_cross_entropy + +import pdb + +def ranking_loss(error, penalize_ratio=0.7, extra_weights=None , type='mean'): + error, indices = torch.sort(error) + # only sum relatively small errors + s_error = torch.index_select(error, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + if extra_weights is not None: + weights = torch.index_select(extra_weights, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + s_error = s_error * weights + + if type == 'mean': + return torch.mean(s_error) + elif type == 'sum': + return torch.sum(s_error) + +@systems.register('pinhole-neus-system') +class PinholeNeuSSystem(BaseSystem): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. 
rank_zero_info: use the logging module + """ + def prepare(self): + self.criterions = { + 'psnr': PSNR() + } + self.train_num_samples = self.config.model.train_num_rays * (self.config.model.num_samples_per_ray + self.config.model.get('num_samples_per_ray_bg', 0)) + self.train_num_rays = self.config.model.train_num_rays + self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) + + def forward(self, batch): + return self.model(batch['rays']) + + def preprocess_data(self, batch, stage): + if 'index' in batch: # validation / testing + index = batch['index'] + else: + if self.config.model.batch_image_sampling: + index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) + else: + index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) + if stage in ['train']: + c2w = self.dataset.all_c2w[index] + x = torch.randint( + 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + y = torch.randint( + 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions[y, x] + # origins = self.dataset.origins[y, x] + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index, y, x] + # origins = self.dataset.origins[index, y, x] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + normal = self.dataset.all_normals_world[index, y, x].view(-1, self.dataset.all_normals_world.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index, y, x].view(-1).to(self.rank) + view_weights = self.dataset.view_weights[index, y, x].view(-1).to(self.rank) + else: + c2w = self.dataset.all_c2w[index][0] + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions + # origins = self.dataset.origins + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index][0] + # origins = self.dataset.origins[index][0] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + normal = self.dataset.all_normals_world[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index].view(-1).to(self.rank) + view_weights = None + + cosines = self.cos(rays_d, normal) + rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) + + if stage in ['train']: + if self.config.model.background_color == 'white': + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'black': + self.model.background_color = torch.zeros((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'random': + self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) + else: + raise NotImplementedError + else: + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + + if self.dataset.apply_mask: + rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) + + batch.update({ + 'rays': rays, + 'rgb': rgb, + 'normal': normal, + 
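+            # fg_mask gates the mask/normal losses in training_step, while rgb_mask (the
+            # valid-color region, further intersected with rays_valid_full) gates the
+            # photometric MSE/L1 losses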
'fg_mask': fg_mask, + 'rgb_mask': rgb_mask, + 'cosines': cosines, + 'view_weights': view_weights + }) + + def training_step(self, batch, batch_idx): + out = self(batch) + + cosines = batch['cosines'] + fg_mask = batch['fg_mask'] + rgb_mask = batch['rgb_mask'] + view_weights = batch['view_weights'] + + cosines[cosines > -0.1] = 0 + mask = ((fg_mask > 0) & (cosines < -0.1)) + rgb_mask = out['rays_valid_full'][...,0] & (rgb_mask > 0) + + grad_cosines = self.cos(batch['rays'][...,3:], out['comp_normal']).detach() + # grad_cosines = cosines + + loss = 0. + + # update train_num_rays + if self.config.model.dynamic_ray_sampling: + train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples_full'].sum().item())) + self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) + + erros_rgb_mse = F.mse_loss(out['comp_rgb_full'][rgb_mask], batch['rgb'][rgb_mask], reduction='none') + # erros_rgb_mse = erros_rgb_mse * torch.exp(grad_cosines.abs())[:, None][rgb_mask] / torch.exp(grad_cosines.abs()[rgb_mask]).sum() + # loss_rgb_mse = ranking_loss(erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type='sum') + loss_rgb_mse = ranking_loss(erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type='mean') + self.log('train/loss_rgb_mse', loss_rgb_mse, prog_bar=True, rank_zero_only=True) + loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse) + + loss_rgb_l1 = F.l1_loss(out['comp_rgb_full'][rgb_mask], batch['rgb'][rgb_mask], reduction='none') + loss_rgb_l1 = ranking_loss(loss_rgb_l1.sum(dim=1), + extra_weights=view_weights[rgb_mask], + penalize_ratio=0.8) + self.log('train/loss_rgb', loss_rgb_l1) + loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1) + + normal_errors = 1 - F.cosine_similarity(out['comp_normal'], batch['normal'], dim=1) + # normal_errors = normal_errors * cosines.abs() / cosines.abs().sum() + normal_errors = normal_errors * torch.exp(cosines.abs()) / torch.exp(cosines.abs()).sum() + loss_normal = ranking_loss(normal_errors[mask], penalize_ratio=0.8, + # extra_weights=view_weights[mask], + type='sum') + self.log('train/loss_normal', loss_normal, prog_bar=True, rank_zero_only=True) + loss += loss_normal * self.C(self.config.system.loss.lambda_normal) + + loss_eikonal = ((torch.linalg.norm(out['sdf_grad_samples'], ord=2, dim=-1) - 1.)**2).mean() + self.log('train/loss_eikonal', loss_eikonal, prog_bar=True, rank_zero_only=True) + loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal) + + opacity = torch.clamp(out['opacity'].squeeze(-1), 1.e-3, 1.-1.e-3) + loss_mask = binary_cross_entropy(opacity, batch['fg_mask'].float(), reduction='none') + loss_mask = ranking_loss(loss_mask, penalize_ratio=0.9, extra_weights=view_weights) + self.log('train/loss_mask', loss_mask, prog_bar=True, rank_zero_only=True) + loss += loss_mask * (self.C(self.config.system.loss.lambda_mask) if self.dataset.has_mask else 0.0) + + loss_opaque = binary_cross_entropy(opacity, opacity) + self.log('train/loss_opaque', loss_opaque) + loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque) + + loss_sparsity = torch.exp(-self.config.system.loss.sparsity_scale * out['random_sdf'].abs()).mean() + self.log('train/loss_sparsity', loss_sparsity, prog_bar=True, rank_zero_only=True) + loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity) + + if self.C(self.config.system.loss.lambda_curvature) > 0: + assert 'sdf_laplace_samples' in out, "Need geometry.grad_type='finite_difference' to get SDF 
Laplace samples" + loss_curvature = out['sdf_laplace_samples'].abs().mean() + self.log('train/loss_curvature', loss_curvature) + loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature) + + # distortion loss proposed in MipNeRF360 + # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss + if self.C(self.config.system.loss.lambda_distortion) > 0: + loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) + self.log('train/loss_distortion', loss_distortion) + loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) + + if self.config.model.learned_background and self.C(self.config.system.loss.lambda_distortion_bg) > 0: + loss_distortion_bg = flatten_eff_distloss(out['weights_bg'], out['points_bg'], out['intervals_bg'], out['ray_indices_bg']) + self.log('train/loss_distortion_bg', loss_distortion_bg) + loss += loss_distortion_bg * self.C(self.config.system.loss.lambda_distortion_bg) + + if self.C(self.config.system.loss.lambda_3d_normal_smooth) > 0: + if "random_sdf_grad" not in out: + raise ValueError( + "random_sdf_grad is required for normal smooth loss, no normal is found in the output." + ) + if "normal_perturb" not in out: + raise ValueError( + "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output." + ) + normals_3d = out["random_sdf_grad"] + normals_perturb_3d = out["normal_perturb"] + loss_3d_normal_smooth = (normals_3d - normals_perturb_3d).abs().mean() + self.log('train/loss_3d_normal_smooth', loss_3d_normal_smooth, prog_bar=True ) + + loss += loss_3d_normal_smooth * self.C(self.config.system.loss.lambda_3d_normal_smooth) + + losses_model_reg = self.model.regularizations(out) + for name, value in losses_model_reg.items(): + self.log(f'train/loss_{name}', value) + loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) + loss += loss_ + + self.log('train/inv_s', out['inv_s'], prog_bar=True) + + for name, value in self.config.system.loss.items(): + if name.startswith('lambda'): + self.log(f'train_params/{name}', self.C(value)) + + self.log('train/num_rays', float(self.train_num_rays), prog_bar=True) + + return { + 'loss': loss + } + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + ] + ([ + {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + ] if self.config.model.learned_background else []) + [ + {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def 
validation_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) + self.export() + + def test_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + ] + ([ + {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + ] if self.config.model.learned_background else []) + [ + {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + def test_epoch_end(self, out): + """ + Synchronize devices. + Generate image sequence using test outputs. + """ + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) + + self.save_img_sequence( + f"it{self.global_step}-test", + f"it{self.global_step}-test", + '(\d+)\.png', + save_format='mp4', + fps=30 + ) + + self.export() + + def export(self): + mesh = self.model.export(self.config.export) + self.save_mesh( + f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", + ortho_scale=self.config.export.ortho_scale, + **mesh + ) diff --git a/instant-nsr-pl/systems/utils.py b/instant-nsr-pl/systems/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dafae78295305113fd1854e9104bf44be24f4727 --- /dev/null +++ b/instant-nsr-pl/systems/utils.py @@ -0,0 +1,351 @@ +import sys +import warnings +from bisect import bisect_right + +import torch +import torch.nn as nn +from torch.optim import lr_scheduler + +from pytorch_lightning.utilities.rank_zero import rank_zero_debug + + +class ChainedScheduler(lr_scheduler._LRScheduler): + """Chains list of learning rate schedulers. It takes a list of chainable learning + rate schedulers and performs consecutive step() functions belong to them by just + one call. + + Args: + schedulers (list): List of chained schedulers. + + Example: + >>> # Assuming optimizer uses lr = 1. 
for all groups + >>> # lr = 0.09 if epoch == 0 + >>> # lr = 0.081 if epoch == 1 + >>> # lr = 0.729 if epoch == 2 + >>> # lr = 0.6561 if epoch == 3 + >>> # lr = 0.59049 if epoch >= 4 + >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) + >>> scheduler = ChainedScheduler([scheduler1, scheduler2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, schedulers): + for scheduler_idx in range(1, len(schedulers)): + if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): + raise ValueError( + "ChainedScheduler expects all schedulers to belong to the same optimizer, but " + "got schedulers at index {} and {} to be different".format(0, scheduler_idx) + ) + self._schedulers = list(schedulers) + self.optimizer = optimizer + + def step(self): + for scheduler in self._schedulers: + scheduler.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} + state_dict['_schedulers'] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict['_schedulers'][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop('_schedulers') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['_schedulers'] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class SequentialLR(lr_scheduler._LRScheduler): + """Receives the list of schedulers that is expected to be called sequentially during + optimization process and milestone points that provides exact intervals to reflect + which scheduler is supposed to be called at a given epoch. + + Args: + schedulers (list): List of chained schedulers. + milestones (list): List of integers that reflects milestone points. + + Example: + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.1 if epoch == 0 + >>> # lr = 0.1 if epoch == 1 + >>> # lr = 0.9 if epoch == 2 + >>> # lr = 0.81 if epoch == 3 + >>> # lr = 0.729 if epoch == 4 + >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) + >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + + def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False): + for scheduler_idx in range(1, len(schedulers)): + if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): + raise ValueError( + "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " + "got schedulers at index {} and {} to be different".format(0, scheduler_idx) + ) + if (len(milestones) != len(schedulers) - 1): + raise ValueError( + "Sequential Schedulers expects number of schedulers provided to be one more " + "than the number of milestone points, but got number of schedulers {} and the " + "number of milestones to be equal to {}".format(len(schedulers), len(milestones)) + ) + self._schedulers = schedulers + self._milestones = milestones + self.last_epoch = last_epoch + 1 + self.optimizer = optimizer + + def step(self): + self.last_epoch += 1 + idx = bisect_right(self._milestones, self.last_epoch) + if idx > 0 and self._milestones[idx - 1] == self.last_epoch: + self._schedulers[idx].step(0) + else: + self._schedulers[idx].step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} + state_dict['_schedulers'] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict['_schedulers'][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop('_schedulers') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['_schedulers'] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class ConstantLR(lr_scheduler._LRScheduler): + """Decays the learning rate of each parameter group by a small constant factor until the + number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply learning rate until the milestone. Default: 1./3. + total_iters (int): The number of steps that the scheduler decays the learning rate. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + + def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): + if factor > 1.0 or factor < 0: + raise ValueError('Constant multiplicative factor expected to be between 0 and 1.') + + self.factor = factor + self.total_iters = total_iters + super(ConstantLR, self).__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.factor for group in self.optimizer.param_groups] + + if (self.last_epoch > self.total_iters or + (self.last_epoch != self.total_iters)): + return [group['lr'] for group in self.optimizer.param_groups] + + if (self.last_epoch == self.total_iters): + return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs] + + +class LinearLR(lr_scheduler._LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small + multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply learning rate in the first epoch. + The multiplication factor changes towards end_factor in the following epochs. + Default: 1./3. + end_factor (float): The number we multiply learning rate at the end of linear changing + process. Default: 1.0. + total_iters (int): The number of iterations that multiplicative factor reaches to 1. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1, + verbose=False): + if start_factor > 1.0 or start_factor < 0: + raise ValueError('Starting multiplicative factor expected to be between 0 and 1.') + + if end_factor > 1.0 or end_factor < 0: + raise ValueError('Ending multiplicative factor expected to be between 0 and 1.') + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + super(LinearLR, self).__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.start_factor for group in self.optimizer.param_groups] + + if (self.last_epoch > self.total_iters): + return [group['lr'] for group in self.optimizer.param_groups] + + return [group['lr'] * (1. 
+ (self.end_factor - self.start_factor) / + (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))) + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.start_factor + + (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters) + for base_lr in self.base_lrs] + + +custom_schedulers = ['ConstantLR', 'LinearLR'] +def get_scheduler(name): + if hasattr(lr_scheduler, name): + return getattr(lr_scheduler, name) + elif name in custom_schedulers: + return getattr(sys.modules[__name__], name) + else: + raise NotImplementedError + + +def getattr_recursive(m, attr): + for name in attr.split('.'): + m = getattr(m, name) + return m + + +def get_parameters(model, name): + module = getattr_recursive(model, name) + if isinstance(module, nn.Module): + return module.parameters() + elif isinstance(module, nn.Parameter): + return module + return [] + + +def parse_optimizer(config, model): + if hasattr(config, 'params'): + params = [{'params': get_parameters(model, name), 'name': name, **args} for name, args in config.params.items()] + rank_zero_debug('Specify optimizer params:', config.params) + else: + params = model.parameters() + if config.name in ['FusedAdam']: + import apex + optim = getattr(apex.optimizers, config.name)(params, **config.args) + else: + optim = getattr(torch.optim, config.name)(params, **config.args) + return optim + + +def parse_scheduler(config, optimizer): + interval = config.get('interval', 'epoch') + assert interval in ['epoch', 'step'] + if config.name == 'SequentialLR': + scheduler = { + 'scheduler': SequentialLR(optimizer, [parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers], milestones=config.milestones), + 'interval': interval + } + elif config.name == 'Chained': + scheduler = { + 'scheduler': ChainedScheduler([parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers]), + 'interval': interval + } + else: + scheduler = { + 'scheduler': get_scheduler(config.name)(optimizer, **config.args), + 'interval': interval + } + return scheduler + + +def update_module_step(m, epoch, global_step): + if hasattr(m, 'update_step'): + m.update_step(epoch, global_step) diff --git a/instant-nsr-pl/utils/__init__.py b/instant-nsr-pl/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/instant-nsr-pl/utils/callbacks.py b/instant-nsr-pl/utils/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..22f39efdb2f381ff677f5311c0586fbad88ae34f --- /dev/null +++ b/instant-nsr-pl/utils/callbacks.py @@ -0,0 +1,99 @@ +import os +import subprocess +import shutil +from utils.misc import dump_config, parse_version + + +import pytorch_lightning +if parse_version(pytorch_lightning.__version__) > parse_version('1.8'): + from pytorch_lightning.callbacks import Callback +else: + from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn +from pytorch_lightning.callbacks.progress import TQDMProgressBar + + +class VersionedCallback(Callback): + def __init__(self, save_root, version=None, use_version=True): + self.save_root = save_root + self._version = version + self.use_version = use_version + + @property + def version(self) -> int: + """Get the experiment version. + + Returns: + The experiment version if specified else the next version. 
+ """ + if self._version is None: + self._version = self._get_next_version() + return self._version + + def _get_next_version(self): + existing_versions = [] + if os.path.isdir(self.save_root): + for f in os.listdir(self.save_root): + bn = os.path.basename(f) + if bn.startswith("version_"): + dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") + existing_versions.append(int(dir_ver)) + if len(existing_versions) == 0: + return 0 + return max(existing_versions) + 1 + + @property + def savedir(self): + if not self.use_version: + return self.save_root + return os.path.join(self.save_root, self.version if isinstance(self.version, str) else f"version_{self.version}") + + +class CodeSnapshotCallback(VersionedCallback): + def __init__(self, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + + def get_file_list(self): + return [ + b.decode() for b in + set(subprocess.check_output('git ls-files', shell=True).splitlines()) | + set(subprocess.check_output('git ls-files --others --exclude-standard', shell=True).splitlines()) + ] + + @rank_zero_only + def save_code_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + for f in self.get_file_list(): + if not os.path.exists(f) or os.path.isdir(f): + continue + os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) + shutil.copyfile(f, os.path.join(self.savedir, f)) + + def on_fit_start(self, trainer, pl_module): + try: + self.save_code_snapshot() + except: + rank_zero_warn("Code snapshot is not saved. Please make sure you have git installed and are in a git repository.") + + +class ConfigSnapshotCallback(VersionedCallback): + def __init__(self, config, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + self.config = config + + @rank_zero_only + def save_config_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + dump_config(os.path.join(self.savedir, 'parsed.yaml'), self.config) + shutil.copyfile(self.config.cmd_args['config'], os.path.join(self.savedir, 'raw.yaml')) + + def on_fit_start(self, trainer, pl_module): + self.save_config_snapshot() + + +class CustomProgressBar(TQDMProgressBar): + def get_metrics(self, *args, **kwargs): + # don't show the version number + items = super().get_metrics(*args, **kwargs) + items.pop("v_num", None) + return items diff --git a/instant-nsr-pl/utils/loggers.py b/instant-nsr-pl/utils/loggers.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1a92302a431a75e0c920327208ab11e9559ec8 --- /dev/null +++ b/instant-nsr-pl/utils/loggers.py @@ -0,0 +1,41 @@ +import re +import pprint +import logging + +from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment +from pytorch_lightning.utilities.rank_zero import rank_zero_only + + +class ConsoleLogger(LightningLoggerBase): + def __init__(self, log_keys=[]): + super().__init__() + self.log_keys = [re.compile(k) for k in log_keys] + self.dict_printer = pprint.PrettyPrinter(indent=2, compact=False).pformat + + def match_log_keys(self, s): + return True if not self.log_keys else any(r.search(s) for r in self.log_keys) + + @property + def name(self): + return 'console' + + @property + def version(self): + return '0' + + @property + @rank_zero_experiment + def experiment(self): + return logging.getLogger('pytorch_lightning') + + @rank_zero_only + def log_hyperparams(self, params): + pass + + @rank_zero_only + def log_metrics(self, metrics, step): + metrics_ = {k: v for k, v in metrics.items() if 
self.match_log_keys(k)} + if not metrics_: + return + self.experiment.info(f"\nEpoch{metrics['epoch']} Step{step}\n{self.dict_printer(metrics_)}") + diff --git a/instant-nsr-pl/utils/misc.py b/instant-nsr-pl/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..c16fafa2ab8e7b934be711c41aed6e12001444fd --- /dev/null +++ b/instant-nsr-pl/utils/misc.py @@ -0,0 +1,54 @@ +import os +from omegaconf import OmegaConf +from packaging import version + + +# ============ Register OmegaConf Recolvers ============= # +OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n)) +OmegaConf.register_new_resolver('add', lambda a, b: a + b) +OmegaConf.register_new_resolver('sub', lambda a, b: a - b) +OmegaConf.register_new_resolver('mul', lambda a, b: a * b) +OmegaConf.register_new_resolver('div', lambda a, b: a / b) +OmegaConf.register_new_resolver('idiv', lambda a, b: a // b) +OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p)) +# ======================================================= # + + +def prompt(question): + inp = input(f"{question} (y/n)").lower().strip() + if inp and inp == 'y': + return True + if inp and inp == 'n': + return False + return prompt(question) + + +def load_config(*yaml_files, cli_args=[]): + yaml_confs = [OmegaConf.load(f) for f in yaml_files] + cli_conf = OmegaConf.from_cli(cli_args) + conf = OmegaConf.merge(*yaml_confs, cli_conf) + OmegaConf.resolve(conf) + return conf + + +def config_to_primitive(config, resolve=True): + return OmegaConf.to_container(config, resolve=resolve) + + +def dump_config(path, config): + with open(path, 'w') as fp: + OmegaConf.save(config=config, f=fp) + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + + +def parse_version(ver): + return version.parse(ver) diff --git a/instant-nsr-pl/utils/mixins.py b/instant-nsr-pl/utils/mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..b556cebc2cad678f89cb6aeb1c08bd6f2b2df920 --- /dev/null +++ b/instant-nsr-pl/utils/mixins.py @@ -0,0 +1,264 @@ +import os +import re +import shutil +import numpy as np +import cv2 +import imageio +from matplotlib import cm +from matplotlib.colors import LinearSegmentedColormap +import json + +import torch + +from utils.obj import write_obj + + +class SaverMixin(): + @property + def save_dir(self): + return self.config.save_dir + + def convert_data(self, data): + if isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + return data.cpu().numpy() + elif isinstance(data, list): + return [self.convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: self.convert_data(v) for k, v in data.items()} + else: + raise TypeError('Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting', type(data)) + + def get_save_path(self, filename): + save_path = os.path.join(self.save_dir, filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + return save_path + + DEFAULT_RGB_KWARGS = {'data_format': 'CHW', 'data_range': (0, 1)} + DEFAULT_UV_KWARGS = {'data_format': 'CHW', 'data_range': (0, 1), 'cmap': 'checkerboard'} + DEFAULT_GRAYSCALE_KWARGS = {'data_range': None, 'cmap': 'jet'} + + def get_rgb_image_(self, img, data_format, data_range): + img = 
self.convert_data(img) + assert data_format in ['CHW', 'HWC'] + if data_format == 'CHW': + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], max=data_range[1]) + img = ((img - data_range[0]) / (data_range[1] - data_range[0]) * 255.).astype(np.uint8) + imgs = [img[...,start:start+3] for start in range(0, img.shape[-1], 3)] + imgs = [img_ if img_.shape[-1] == 3 else np.concatenate([img_, np.zeros((img_.shape[0], img_.shape[1], 3 - img_.shape[2]), dtype=img_.dtype)], axis=-1) for img_ in imgs] + img = np.concatenate(imgs, axis=1) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def save_rgb_image(self, filename, img, data_format=DEFAULT_RGB_KWARGS['data_format'], data_range=DEFAULT_RGB_KWARGS['data_range']): + img = self.get_rgb_image_(img, data_format, data_range) + cv2.imwrite(self.get_save_path(filename), img) + + def get_uv_image_(self, img, data_format, data_range, cmap): + img = self.convert_data(img) + assert data_format in ['CHW', 'HWC'] + if data_format == 'CHW': + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], max=data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in ['checkerboard', 'color'] + if cmap == 'checkerboard': + n_grid = 64 + mask = (img * n_grid).astype(int) + mask = (mask[...,0] + mask[...,1]) % 2 == 0 + img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 + img[mask] = np.array([255, 0, 255], dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif cmap == 'color': + img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) + img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) + img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) + img = img_ + return img + + def save_uv_image(self, filename, img, data_format=DEFAULT_UV_KWARGS['data_format'], data_range=DEFAULT_UV_KWARGS['data_range'], cmap=DEFAULT_UV_KWARGS['cmap']): + img = self.get_uv_image_(img, data_format, data_range, cmap) + cv2.imwrite(self.get_save_path(filename), img) + + def get_grayscale_image_(self, img, data_range, cmap): + img = self.convert_data(img) + img = np.nan_to_num(img) + if data_range is None: + img = (img - img.min()) / (img.max() - img.min()) + else: + img = img.clip(data_range[0], data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in [None, 'jet', 'magma'] + if cmap == None: + img = (img * 255.).astype(np.uint8) + img = np.repeat(img[...,None], 3, axis=2) + elif cmap == 'jet': + img = (img * 255.).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + elif cmap == 'magma': + img = 1. - img + base = cm.get_cmap('magma') + num_bins = 256 + colormap = LinearSegmentedColormap.from_list( + f"{base.name}{num_bins}", + base(np.linspace(0, 1, num_bins)), + num_bins + )(np.linspace(0, 1, num_bins))[:,:3] + a = np.floor(img * 255.) + b = (a + 1).clip(max=255.) + f = img * 255. 
- a + a = a.astype(np.uint16).clip(0, 255) + b = b.astype(np.uint16).clip(0, 255) + img = colormap[a] + (colormap[b] - colormap[a]) * f[...,None] + img = (img * 255.).astype(np.uint8) + return img + + def save_grayscale_image(self, filename, img, data_range=DEFAULT_GRAYSCALE_KWARGS['data_range'], cmap=DEFAULT_GRAYSCALE_KWARGS['cmap']): + img = self.get_grayscale_image_(img, data_range, cmap) + cv2.imwrite(self.get_save_path(filename), img) + + def get_image_grid_(self, imgs): + if isinstance(imgs[0], list): + return np.concatenate([self.get_image_grid_(row) for row in imgs], axis=0) + cols = [] + for col in imgs: + assert col['type'] in ['rgb', 'uv', 'grayscale'] + if col['type'] == 'rgb': + rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() + rgb_kwargs.update(col['kwargs']) + cols.append(self.get_rgb_image_(col['img'], **rgb_kwargs)) + elif col['type'] == 'uv': + uv_kwargs = self.DEFAULT_UV_KWARGS.copy() + uv_kwargs.update(col['kwargs']) + cols.append(self.get_uv_image_(col['img'], **uv_kwargs)) + elif col['type'] == 'grayscale': + grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() + grayscale_kwargs.update(col['kwargs']) + cols.append(self.get_grayscale_image_(col['img'], **grayscale_kwargs)) + return np.concatenate(cols, axis=1) + + def save_image_grid(self, filename, imgs): + img = self.get_image_grid_(imgs) + cv2.imwrite(self.get_save_path(filename), img) + + def save_image(self, filename, img): + img = self.convert_data(img) + assert img.dtype == np.uint8 + if img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + cv2.imwrite(self.get_save_path(filename), img) + + def save_cubemap(self, filename, img, data_range=(0, 1)): + img = self.convert_data(img) + assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] + + imgs_full = [] + for start in range(0, img.shape[-1], 3): + img_ = img[...,start:start+3] + img_ = np.stack([self.get_rgb_image_(img_[i], 'HWC', data_range) for i in range(img_.shape[0])], axis=0) + size = img_.shape[1] + placeholder = np.zeros((size, size, 3), dtype=np.float32) + img_full = np.concatenate([ + np.concatenate([placeholder, img_[2], placeholder, placeholder], axis=1), + np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), + np.concatenate([placeholder, img_[3], placeholder, placeholder], axis=1) + ], axis=0) + img_full = cv2.cvtColor(img_full, cv2.COLOR_RGB2BGR) + imgs_full.append(img_full) + + imgs_full = np.concatenate(imgs_full, axis=1) + cv2.imwrite(self.get_save_path(filename), imgs_full) + + def save_data(self, filename, data): + data = self.convert_data(data) + if isinstance(data, dict): + if not filename.endswith('.npz'): + filename += '.npz' + np.savez(self.get_save_path(filename), **data) + else: + if not filename.endswith('.npy'): + filename += '.npy' + np.save(self.get_save_path(filename), data) + + def save_state_dict(self, filename, data): + torch.save(data, self.get_save_path(filename)) + + def save_img_sequence(self, filename, img_dir, matcher, save_format='gif', fps=30): + assert save_format in ['gif', 'mp4'] + if not filename.endswith(save_format): + filename += f".{save_format}" + matcher = re.compile(matcher) + img_dir = os.path.join(self.save_dir, img_dir) + imgs = [] + for f in os.listdir(img_dir): + if matcher.search(f): + imgs.append(f) + imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) + imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] + + if save_format == 'gif': + imgs = [cv2.cvtColor(i, 
cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(self.get_save_path(filename), imgs, fps=fps, palettesize=256) + elif save_format == 'mp4': + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(self.get_save_path(filename), imgs, fps=fps) + + def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None, v_rgb=None, ortho_scale=1): + v_pos, t_pos_idx = self.convert_data(v_pos), self.convert_data(t_pos_idx) + if v_rgb is not None: + v_rgb = self.convert_data(v_rgb) + + if ortho_scale is not None: + print("ortho scale is: ", ortho_scale) + v_pos = v_pos * ortho_scale * 0.5 + + # change to front-facing + v_pos_copy = np.zeros_like(v_pos) + v_pos_copy[:, 0] = v_pos[:, 0] + v_pos_copy[:, 1] = v_pos[:, 2] + v_pos_copy[:, 2] = v_pos[:, 1] + + import trimesh + mesh = trimesh.Trimesh( + vertices=v_pos_copy, + faces=t_pos_idx, + vertex_colors=v_rgb + ) + trimesh.repair.fix_inversion(mesh) + mesh.export(self.get_save_path(filename)) + # mesh.export(self.get_save_path(filename.replace(".obj", "-meshlab.obj"))) + + # v_pos_copy[:, 0] = v_pos[:, 1] * -1 + # v_pos_copy[:, 1] = v_pos[:, 0] + # v_pos_copy[:, 2] = v_pos[:, 2] + + # mesh = trimesh.Trimesh( + # vertices=v_pos_copy, + # faces=t_pos_idx, + # vertex_colors=v_rgb + # ) + # mesh.export(self.get_save_path(filename.replace(".obj", "-blender.obj"))) + + + # v_pos_copy[:, 0] = v_pos[:, 0] + # v_pos_copy[:, 1] = v_pos[:, 1] * -1 + # v_pos_copy[:, 2] = v_pos[:, 2] * -1 + + # mesh = trimesh.Trimesh( + # vertices=v_pos_copy, + # faces=t_pos_idx, + # vertex_colors=v_rgb + # ) + # mesh.export(self.get_save_path(filename.replace(".obj", "-opengl.obj"))) + + def save_file(self, filename, src_path): + shutil.copyfile(src_path, self.get_save_path(filename)) + + def save_json(self, filename, payload): + with open(self.get_save_path(filename), 'w') as f: + f.write(json.dumps(payload)) diff --git a/instant-nsr-pl/utils/obj.py b/instant-nsr-pl/utils/obj.py new file mode 100644 index 0000000000000000000000000000000000000000..da6d11938c244a6b982ec8c3f36a90a7a5fd2831 --- /dev/null +++ b/instant-nsr-pl/utils/obj.py @@ -0,0 +1,74 @@ +import numpy as np + + +def load_obj(filename): + # Read entire file + with open(filename, 'r') as f: + lines = f.readlines() + + # load vertices + vertices, texcoords = [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'v': + vertices.append([float(v) for v in line.split()[1:]]) + elif prefix == 'vt': + val = [float(v) for v in line.split()[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + + uv = len(texcoords) > 0 + faces, tfaces = [], [] + for line in lines: + if len(line.split()) == 0: + continue + prefix = line.split()[0].lower() + if prefix == 'usemtl': # Track used materials + pass + elif prefix == 'f': # Parse face + vs = line.split()[1:] + nv = len(vs) + vv = vs[0].split('/') + v0 = int(vv[0]) - 1 + if uv: + t0 = int(vv[1]) - 1 if vv[1] != "" else -1 + for i in range(nv - 2): # Triangulate polygons + vv1 = vs[i + 1].split('/') + v1 = int(vv1[0]) - 1 + vv2 = vs[i + 2].split('/') + v2 = int(vv2[0]) - 1 + faces.append([v0, v1, v2]) + if uv: + t1 = int(vv1[1]) - 1 if vv1[1] != "" else -1 + t2 = int(vv2[1]) - 1 if vv2[1] != "" else -1 + tfaces.append([t0, t1, t2]) + vertices = np.array(vertices, dtype=np.float32) + faces = np.array(faces, dtype=np.int64) + if uv: + assert len(tfaces) == len(faces) + texcoords = np.array(texcoords, dtype=np.float32) + tfaces = np.array(tfaces, dtype=np.int64) + else: + texcoords, tfaces = None, 
None + + return vertices, faces, texcoords, tfaces + + +def write_obj(filename, v_pos, t_pos_idx, v_tex, t_tex_idx): + with open(filename, "w") as f: + for v in v_pos: + f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) + + if v_tex is not None: + assert(len(t_pos_idx) == len(t_tex_idx)) + for v in v_tex: + f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) + + # Write faces + for i in range(len(t_pos_idx)): + f.write("f ") + for j in range(3): + f.write(' %s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1))) + f.write("\n") diff --git a/mvdiffusion/data/fixed_poses/nine_views/000_back_RT.txt b/mvdiffusion/data/fixed_poses/nine_views/000_back_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..7ef2610bdbde3f9c9db89c05fc5606362adc7d0c --- /dev/null +++ b/mvdiffusion/data/fixed_poses/nine_views/000_back_RT.txt @@ -0,0 +1,3 @@ +-5.266582965850830078e-01 7.410295009613037109e-01 -4.165407419204711914e-01 -5.960464477539062500e-08 +5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 -9.462351613365171943e-08 +8.500770330429077148e-01 4.590988159179687500e-01 -2.580644786357879639e-01 -1.300000071525573730e+00 diff --git a/mvdiffusion/data/fixed_poses/nine_views/000_back_left_RT.txt b/mvdiffusion/data/fixed_poses/nine_views/000_back_left_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..7db25bbbec8a0d5a26724aa65681603e5bee6744 --- /dev/null +++ b/mvdiffusion/data/fixed_poses/nine_views/000_back_left_RT.txt @@ -0,0 +1,3 @@ +-9.734988808631896973e-01 1.993551850318908691e-01 -1.120596975088119507e-01 -1.713633537292480469e-07 +3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 1.772203575001185527e-07 +2.286916375160217285e-01 8.486189246177673340e-01 -4.770178496837615967e-01 -1.838477611541748047e+00 diff --git a/mvdiffusion/data/fixed_poses/nine_views/000_back_right_RT.txt b/mvdiffusion/data/fixed_poses/nine_views/000_back_right_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..be45ed8f30f6625421524d141aa1c325dc2fdb8b --- /dev/null +++ b/mvdiffusion/data/fixed_poses/nine_views/000_back_right_RT.txt @@ -0,0 +1,3 @@ +2.286914736032485962e-01 8.486190438270568848e-01 -4.770178198814392090e-01 1.564621925354003906e-07 +-3.417914484771245043e-08 4.900034070014953613e-01 8.717205524444580078e-01 -7.293811421504869941e-08 +9.734990000724792480e-01 -1.993550658226013184e-01 1.120596155524253845e-01 -1.838477969169616699e+00 diff --git a/mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt b/mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..8278639ec5ec9d0f1e4c88d54295a0cd4acee593 --- /dev/null +++ b/mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt @@ -0,0 +1,3 @@ +5.266583561897277832e-01 -7.410295009613037109e-01 4.165407419204711914e-01 0.000000000000000000e+00 +5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 9.462351613365171943e-08 +-8.500770330429077148e-01 -4.590988159179687500e-01 2.580645382404327393e-01 -1.300000071525573730e+00 diff --git a/mvdiffusion/data/fixed_poses/nine_views/000_front_left_RT.txt b/mvdiffusion/data/fixed_poses/nine_views/000_front_left_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..6255b9f84ccb1bf3527897ef648811ff171aa025 --- /dev/null +++ b/mvdiffusion/data/fixed_poses/nine_views/000_front_left_RT.txt @@ -0,0 +1,3 @@ +-2.286916971206665039e-01 -8.486189842224121094e-01 
4.770179092884063721e-01 -2.458691596984863281e-07 +9.085837859856837895e-09 4.900034666061401367e-01 8.717205524444580078e-01 1.205695667749751010e-07 +-9.734990000724792480e-01 1.993551701307296753e-01 -1.120597645640373230e-01 -1.838477969169616699e+00 diff --git a/mvdiffusion/data/fixed_poses/nine_views/000_front_right_RT.txt b/mvdiffusion/data/fixed_poses/nine_views/000_front_right_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..e1d76c85c1a05de1d6bf3dd70ed02721a40f73c2 --- /dev/null +++ b/mvdiffusion/data/fixed_poses/nine_views/000_front_right_RT.txt @@ -0,0 +1,3 @@ +9.734989404678344727e-01 -1.993551850318908691e-01 1.120596975088119507e-01 -1.415610313415527344e-07 +3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 -1.772203575001185527e-07 +-2.286916375160217285e-01 -8.486189246177673340e-01 4.770178794860839844e-01 -1.838477611541748047e+00 diff --git a/mvdiffusion/data/fixed_poses/nine_views/000_left_RT.txt b/mvdiffusion/data/fixed_poses/nine_views/000_left_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..bd42197eaae14526b00cb4676528a2465cbaf1dd --- /dev/null +++ b/mvdiffusion/data/fixed_poses/nine_views/000_left_RT.txt @@ -0,0 +1,3 @@ +-8.500771522521972656e-01 -4.590989053249359131e-01 2.580644488334655762e-01 0.000000000000000000e+00 +-4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 9.006067358541258727e-08 +-5.266583561897277832e-01 7.410295605659484863e-01 -4.165408313274383545e-01 -1.300000071525573730e+00 diff --git a/mvdiffusion/data/fixed_poses/nine_views/000_right_RT.txt b/mvdiffusion/data/fixed_poses/nine_views/000_right_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..2d37c0219db99aeeede6b48815b932058538adc6 --- /dev/null +++ b/mvdiffusion/data/fixed_poses/nine_views/000_right_RT.txt @@ -0,0 +1,3 @@ +8.500770330429077148e-01 4.590989053249359131e-01 -2.580644488334655762e-01 5.960464477539062500e-08 +-4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 -9.006067358541258727e-08 +5.266583561897277832e-01 -7.410295605659484863e-01 4.165407419204711914e-01 -1.300000071525573730e+00 diff --git a/mvdiffusion/data/fixed_poses/nine_views/000_top_RT.txt b/mvdiffusion/data/fixed_poses/nine_views/000_top_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..1d71f22664b502cd5e1039f4621bec6ef41b1231 --- /dev/null +++ b/mvdiffusion/data/fixed_poses/nine_views/000_top_RT.txt @@ -0,0 +1,3 @@ +9.958608150482177734e-01 7.923202216625213623e-02 -4.453715682029724121e-02 -3.098167056236889039e-09 +-9.089154005050659180e-02 8.681122064590454102e-01 -4.879753291606903076e-01 5.784738377201392723e-08 +-2.028124157504862524e-08 4.900035560131072998e-01 8.717204332351684570e-01 -1.300000071525573730e+00 diff --git a/mvdiffusion/data/normal_utils.py b/mvdiffusion/data/normal_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dff3730a312e96a2ed82dfd5a337d263baa0f2d8 --- /dev/null +++ b/mvdiffusion/data/normal_utils.py @@ -0,0 +1,45 @@ +import numpy as np + +def camNormal2worldNormal(rot_c2w, camNormal): + H,W,_ = camNormal.shape + normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + + return normal_img + +def worldNormal2camNormal(rot_w2c, normal_map_world): + H,W,_ = normal_map_world.shape + # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + + # faster version + # Reshape the normal map 
into a 2D array where each row represents a normal vector + normal_map_flat = normal_map_world.reshape(-1, 3) + + # Transform the normal vectors using the transformation matrix + normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T) + + # Reshape the transformed normal map back to its original shape + normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape) + + return normal_map_camera + +def trans_normal(normal, RT_w2c, RT_w2c_target): + + # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) + # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world) + + relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3])) + normal_target_cam = worldNormal2camNormal(relative_RT[:3,:3], normal) + + return normal_target_cam + +def img2normal(img): + return (img/255.)*2-1 + +def normal2img(normal): + return np.uint8((normal*0.5+0.5)*255) + +def norm_normalize(normal, dim=-1): + + normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6) + + return normal \ No newline at end of file diff --git a/mvdiffusion/data/single_image_dataset.py b/mvdiffusion/data/single_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b28b54ce60fef6e4513c5f41e2b0c7dab662b4 --- /dev/null +++ b/mvdiffusion/data/single_image_dataset.py @@ -0,0 +1,305 @@ +from typing import Dict +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +from einops import rearrange +from typing import Literal, Tuple, Optional, Any +import cv2 +import random + +import json +import os, sys +import math + +from glob import glob + +import PIL.Image +from .normal_utils import trans_normal, normal2img, img2normal +import pdb + + +import cv2 +import numpy as np + +def add_margin(pil_img, color=0, size=256): + width, height = pil_img.size + result = Image.new(pil_img.mode, (size, size), color) + result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) + return result + +def scale_and_place_object(image, scale_factor): + assert np.shape(image)[-1]==4 # RGBA + + # Extract the alpha channel (transparency) and the object (RGB channels) + alpha_channel = image[:, :, 3] + + # Find the bounding box coordinates of the object + coords = cv2.findNonZero(alpha_channel) + x, y, width, height = cv2.boundingRect(coords) + + # Calculate the scale factor for resizing + original_height, original_width = image.shape[:2] + + if width > height: + size = width + original_size = original_width + else: + size = height + original_size = original_height + + scale_factor = min(scale_factor, size / (original_size+0.0)) + + new_size = scale_factor * original_size + scale_factor = new_size / size + + # Calculate the new size based on the scale factor + new_width = int(width * scale_factor) + new_height = int(height * scale_factor) + + center_x = original_width // 2 + center_y = original_height // 2 + + paste_x = center_x - (new_width // 2) + paste_y = center_y - (new_height // 2) + + # Resize the object (RGB channels) to the new size + rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height)) + + # Create a new RGBA image with the resized image + new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8) + + new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object + + return new_image + +class 
SingleImageDataset(Dataset): + def __init__(self, + root_dir: str, + num_views: int, + img_wh: Tuple[int, int], + bg_color: str, + crop_size: int = 224, + single_image: Optional[PIL.Image.Image] = None, + num_validation_samples: Optional[int] = None, + filepaths: Optional[list] = None, + cond_type: Optional[str] = None + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = root_dir + self.num_views = num_views + self.img_wh = img_wh + self.crop_size = crop_size + self.bg_color = bg_color + self.cond_type = cond_type + + if self.num_views == 4: + self.view_types = ['front', 'right', 'back', 'left'] + elif self.num_views == 5: + self.view_types = ['front', 'front_right', 'right', 'back', 'left'] + elif self.num_views == 6: + self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] + + self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views" + + self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix + + if single_image is None: + if filepaths is None: + # Get a list of all files in the directory + file_list = os.listdir(self.root_dir) + else: + file_list = filepaths + + # Filter the files that end with .png or .jpg + self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg'))] + else: + self.file_list = None + + # load all images + self.all_images = [] + self.all_alphas = [] + bg_color = self.get_bg_color() + + if single_image is not None: + image, alpha = self.load_image(None, bg_color, return_type='pt', Imagefile=single_image) + self.all_images.append(image) + self.all_alphas.append(alpha) + else: + for file in self.file_list: + print(os.path.join(self.root_dir, file)) + image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt') + self.all_images.append(image) + self.all_alphas.append(alpha) + + self.all_images = self.all_images[:num_validation_samples] + self.all_alphas = self.all_alphas[:num_validation_samples] + + + def __len__(self): + return len(self.all_images) + + def load_fixed_poses(self): + poses = {} + for face in self.view_types: + RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face))) + poses[face] = RT + + return poses + + def cartesian_to_spherical(self, xyz): + ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) + xy = xyz[:,0]**2 + xyz[:,1]**2 + z = np.sqrt(xy + xyz[:,2]**2) + theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down + #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up + azimuth = np.arctan2(xyz[:,1], xyz[:,0]) + return np.array([theta, azimuth, z]) + + def get_T(self, target_RT, cond_RT): + R, T = target_RT[:3, :3], target_RT[:, -1] + T_target = -R.T @ T # change to cam2world + + R, T = cond_RT[:3, :3], cond_RT[:, -1] + T_cond = -R.T @ T + + theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :]) + theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :]) + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) + return d_theta, d_azimuth + + def get_bg_color(self): + if self.bg_color == 'white': + bg_color = np.array([1., 1., 1.], dtype=np.float32) + elif self.bg_color == 'black': + bg_color = 
np.array([0., 0., 0.], dtype=np.float32) + elif self.bg_color == 'gray': + bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) + elif self.bg_color == 'random': + bg_color = np.random.rand(3) + elif isinstance(self.bg_color, float): + bg_color = np.array([self.bg_color] * 3, dtype=np.float32) + else: + raise NotImplementedError + return bg_color + + + def load_image(self, img_path, bg_color, return_type='np', Imagefile=None): + # pil always returns uint8 + if Imagefile is None: + image_input = Image.open(img_path) + else: + image_input = Imagefile + image_size = self.img_wh[0] + + if self.crop_size!=-1: + alpha_np = np.asarray(image_input)[:, :, 3] + coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] + min_x, min_y = np.min(coords, 0) + max_x, max_y = np.max(coords, 0) + ref_img_ = image_input.crop((min_x, min_y, max_x, max_y)) + h, w = ref_img_.height, ref_img_.width + scale = self.crop_size / max(h, w) + h_, w_ = int(scale * h), int(scale * w) + ref_img_ = ref_img_.resize((w_, h_)) + image_input = add_margin(ref_img_, size=image_size) + else: + image_input = add_margin(image_input, size=max(image_input.height, image_input.width)) + image_input = image_input.resize((image_size, image_size)) + + # img = scale_and_place_object(img, self.scale_ratio) + img = np.array(image_input) + img = img.astype(np.float32) / 255. # [0, 1] + assert img.shape[-1] == 4 # RGBA + + alpha = img[...,3:4] + img = img[...,:3] * alpha + bg_color * (1 - alpha) + + if return_type == "np": + pass + elif return_type == "pt": + img = torch.from_numpy(img) + alpha = torch.from_numpy(alpha) + else: + raise NotImplementedError + + return img, alpha + + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + + image = self.all_images[index%len(self.all_images)] + alpha = self.all_alphas[index%len(self.all_images)] + if self.file_list is not None: + filename = self.file_list[index%len(self.all_images)].replace(".png", "") + else: + filename = 'null' + + cond_w2c = self.fix_cam_poses['front'] + + tgt_w2cs = [self.fix_cam_poses[view] for view in self.view_types] + + elevations = [] + azimuths = [] + + img_tensors_in = [ + image.permute(2, 0, 1) + ] * self.num_views + + alpha_tensors_in = [ + alpha.permute(2, 0, 1) + ] * self.num_views + + for view, tgt_w2c in zip(self.view_types, tgt_w2cs): + # evelations, azimuths + elevation, azimuth = self.get_T(tgt_w2c, cond_w2c) + elevations.append(elevation) + azimuths.append(azimuth) + + img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W) + alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 3, H, W) + + elevations = torch.as_tensor(elevations).float().squeeze(1) + azimuths = torch.as_tensor(azimuths).float().squeeze(1) + elevations_cond = torch.as_tensor([0] * self.num_views).float() + + normal_class = torch.tensor([1, 0]).float() + normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2) + color_class = torch.tensor([0, 1]).float() + color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2) + + camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3) + + out = { + 'elevations_cond': elevations_cond, + 'elevations_cond_deg': torch.rad2deg(elevations_cond), + 'elevations': elevations, + 'azimuths': azimuths, + 'elevations_deg': torch.rad2deg(elevations), + 'azimuths_deg': torch.rad2deg(azimuths), + 'imgs_in': img_tensors_in, + 'alphas': alpha_tensors_in, + 'camera_embeddings': camera_embeddings, + 
'normal_task_embeddings': normal_task_embeddings, + 'color_task_embeddings': color_task_embeddings, + 'filename': filename, + } + + return out + + diff --git a/mvdiffusion/models/transformer_mv2d.py b/mvdiffusion/models/transformer_mv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f0be6bcb554af916fe74ee7d919d0a1a47aa36 --- /dev/null +++ b/mvdiffusion/models/transformer_mv2d.py @@ -0,0 +1,986 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate, maybe_allow_in_graph +from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.import_utils import is_xformers_available + +from einops import rearrange, repeat +import pdb +import random + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def my_repeat(tensor, num_repeats): + """ + Repeat a tensor along a given dimension + """ + if len(tensor.shape) == 3: + return repeat(tensor, "b d c -> (b v) d c", v=num_repeats) + elif len(tensor.shape) == 4: + return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats) + + +@dataclass +class TransformerMV2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class TransformerMV2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 
+ This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + num_views: int = 1, + cd_attention_last: bool=False, + cd_attention_mid: bool=False, + multiview_attention: bool=True, + sparse_mv_attention: bool = False, + mvcd_attention: bool=False + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." 
+ ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + else: + self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicMVTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + for d in range(num_layers) + ] + ) + + # 4. 
Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + else: + self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return TransformerMV2DModelOutput(sample=output) + + +@maybe_allow_in_graph +class BasicMVTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. 
+ num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + mvcd_attention: bool = False + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + self.multiview_attention = multiview_attention + self.sparse_mv_attention = sparse_mv_attention + self.mvcd_attention = mvcd_attention + + self.attn1 = CustomAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=MVAttnProcessor() + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. 
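+        # When `cross_attention_dim` is given (or `double_self_attention` is True),
+        # norm2 + attn2 below form the usual cross-attention sub-block that attends
+        # to `encoder_hidden_states` (e.g. the conditioning embeddings); otherwise
+        # both are left as None and the block runs self-attention + feed-forward only.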
+ self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + self.num_views = num_views + + self.cd_attention_last = cd_attention_last + + if self.cd_attention_last: + # Joint task -Attn + self.attn_joint_last = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data) + self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + + self.cd_attention_mid = cd_attention_mid + + if self.cd_attention_mid: + # print("cross-domain attn in the middle") + # Joint task -Attn + self.attn_joint_mid = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data) + self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + assert attention_mask is None # not supported yet + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. 
Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + num_views=self.num_views, + multiview_attention=self.multiview_attention, + sparse_mv_attention=self.sparse_mv_attention, + mvcd_attention=self.mvcd_attention, + **cross_attention_kwargs, + ) + + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # joint attention twice + if self.cd_attention_mid: + norm_hidden_states = ( + self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states) + ) + hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + if self.cd_attention_last: + norm_hidden_states = ( + self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states) + ) + hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states + + return hidden_states + + +class CustomAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersMVAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + + +class CustomJointAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersJointAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + +class MVAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + multiview_attention=True + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # print('query', query.shape, 'key', key.shape, 'value', value.shape) + #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) + # pdb.set_trace() + # multi-view self-attention + if multiview_attention: + key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 
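+        # Epilogue as in diffusers' default AttnProcessor: optionally add the residual
+        # (when the Attention module was built with residual_connection=True) and
+        # divide by rescale_output_factor, which defaults to 1.0.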
+ + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XFormersMVAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1., + multiview_attention=True, + sparse_mv_attention=False, + mvcd_attention=False, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key_raw = attn.to_k(encoder_hidden_states) + value_raw = attn.to_v(encoder_hidden_states) + + # print('query', query.shape, 'key', key.shape, 'value', value.shape) + #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) + # pdb.set_trace() + # multi-view self-attention + if multiview_attention: + if not sparse_mv_attention: + key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views) + value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views) + else: + key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c] + value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) + key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c + value = torch.cat([value_front, value_raw], dim=1) + + else: + # print("don't use multiview attention.") + key = key_raw + value = value_raw + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + 
hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + + +class XFormersJointAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class JointAttnProcessor: + r""" + Default processor for performing attention-related computations. 
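+
+    Attends jointly across the two task branches (the first and second half of
+    the batch): their keys/values are concatenated along the token dimension and
+    shared by both branches, so only `num_tasks == 2` is supported. This is the
+    plain-PyTorch counterpart of `XFormersJointAttnProcessor`.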
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git a/mvdiffusion/models/unet_mv2d_blocks.py b/mvdiffusion/models/unet_mv2d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..4dfba32e67879ae52b14ef66173830dbf6cdd520 --- /dev/null +++ b/mvdiffusion/models/unet_mv2d_blocks.py @@ -0,0 +1,922 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
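+
+# Multi-view variants of the diffusers 2D UNet blocks: `CrossAttnDownBlockMV2D`,
+# `CrossAttnUpBlockMV2D` and `UNetMidBlockMV2DCrossAttn` mirror their
+# `diffusers.models.unet_2d_blocks` counterparts but wrap `TransformerMV2DModel`,
+# and the `get_down_block`/`get_up_block` factories below additionally accept the
+# multi-view options (`num_views`, `cd_attention_last`, `cd_attention_mid`,
+# `multiview_attention`, `sparse_mv_attention`, `mvcd_attention`) that are passed
+# through to those blocks.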
+from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import is_torch_version, logging +from diffusers.models.attention import AdaGroupNorm +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from diffusers.models.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from mvdiffusion.models.transformer_mv2d import TransformerMV2DModel + +from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D +from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, + num_views=1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + mvcd_attention: bool=False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
+ ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif down_block_type == "CrossAttnDownBlockMV2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D") + return CrossAttnDownBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + elif down_block_type == 
"SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + 
cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, + num_views=1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + mvcd_attention: bool=False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif up_block_type == "CrossAttnUpBlockMV2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D") + return CrossAttnUpBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if 
cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlockMV2DCrossAttn(nn.Module): + def 
__init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + mvcd_attention: bool=False + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + ) + else: + raise NotImplementedError + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlockMV2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: 
bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + mvcd_attention: bool=False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + ) + else: + raise NotImplementedError + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, 
+ **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnDownBlockMV2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + mvcd_attention: bool=False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + ) + else: + raise NotImplementedError + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is 
not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + diff --git a/mvdiffusion/models/unet_mv2d_condition.py b/mvdiffusion/models/unet_mv2d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9ae7dc6d2ee08ee935939023680d074ab15f92 --- /dev/null +++ b/mvdiffusion/models/unet_mv2d_condition.py @@ -0,0 +1,1492 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
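+
+# Multi-view conditional UNet: `UNetMV2DConditionModel` mirrors diffusers'
+# `UNet2DConditionModel` but defaults its down/mid/up block types to the MV2D
+# cross-attention blocks defined in `unet_mv2d_blocks.py`
+# ("CrossAttnDownBlockMV2D", "UNetMidBlockMV2DCrossAttn", "CrossAttnUpBlockMV2D")
+# and forwards the multi-view options (`num_views`, `cd_attention_last`,
+# `cd_attention_mid`, `multiview_attention`, `sparse_mv_attention`,
+# `mvcd_attention`) to `get_down_block`/`get_up_block`.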
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import os + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, +) +from diffusers.utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + FLAX_WEIGHTS_NAME, + HF_HUB_OFFLINE, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_model_file, + deprecate, + is_accelerate_available, + is_safetensors_available, + is_torch_version, + logging, +) +from diffusers import __version__ +from mvdiffusion.models.unet_mv2d_blocks import ( + CrossAttnDownBlockMV2D, + CrossAttnUpBlockMV2D, + UNetMidBlockMV2DCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetMV2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. 
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 
+ time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] 
= None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + mvcd_attention: bool = False + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
+ ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. 
+ # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + # custom MV2D attention block + elif mid_block_type == "UNetMidBlockMV2DCrossAttn": + self.mid_block = UNetMidBlockMV2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + 
cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
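+
+        Example (a minimal sketch; `unet` stands for an already-instantiated model of this class):
+
+        ```py
+        processors = unet.attn_processors
+        # keys are module paths ending in ".processor"; values are the processor instances
+        print(len(processors), sorted(processors)[0])
+        ```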
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
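+
+        Example (illustrative; `unet` refers to an instance of this model):
+
+        ```py
+        # halve every sliceable head dimension (the default speed/memory trade-off)
+        unet.set_attention_slice("auto")
+
+        # or run one slice at a time for maximum memory savings
+        unet.set_attention_slice("max")
+        ```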
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNetMV2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
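+        # Shape sketch (assuming the default "positional" time embedding type): `timesteps`
+        # is expanded to (batch,), `time_proj` maps it to (batch, block_out_channels[0]) in
+        # float32, and the cast plus `time_embedding` below yield `emb` of shape
+        # (batch, time_embed_dim).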
+ t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is 
not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNetMV2DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + camera_embedding_type: str, num_views: int, sample_size: int, + zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, + projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False, + cd_attention_mid: bool = False, multiview_attention: bool = True, + sparse_mv_attention: bool = False, mvcd_attention: bool = False, + in_channels: int = 8, out_channels: int = 4, + **kwargs + ): + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 
+ output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if `device_map` contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. 
You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + Example: + + ```py + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." 
+ ) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + user_agent=user_agent, + **kwargs, + ) + + # modify config + config["_class_name"] = cls.__name__ + config['in_channels'] = in_channels + config['out_channels'] = out_channels + config['sample_size'] = sample_size # training resolution + config['num_views'] = num_views + config['cd_attention_last'] = cd_attention_last + config['cd_attention_mid'] = cd_attention_mid + config['multiview_attention'] = multiview_attention + config['sparse_mv_attention'] = sparse_mv_attention + config['mvcd_attention'] = mvcd_attention + config["down_block_types"] = [ + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D" + ] + config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn" + config["up_block_types"] = [ + "UpBlock2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D" + ] + config['class_embed_type'] = 'projection' + if camera_embedding_type == 'e_de_da_sincos': + config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6 + else: + raise NotImplementedError + + # load model + model_file = None + if from_flax: + raise NotImplementedError + else: + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + model = cls.from_config(config, **unused_kwargs) + import copy + state_dict_v0 = load_state_dict(model_file, variant=variant) + state_dict = copy.deepcopy(state_dict_v0) + # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last + # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid + for key in state_dict_v0: + if 'attn_joint.' in key: + tmp = copy.deepcopy(key) + state_dict[key.replace("attn_joint.", "attn_joint_last.")] = state_dict.pop(tmp) + if 'norm_joint.' in key: + tmp = copy.deepcopy(key) + state_dict[key.replace("norm_joint.", "norm_joint_last.")] = state_dict.pop(tmp) + if 'attn_joint_twice.' 
in key: + tmp = copy.deepcopy(key) + state_dict[key.replace("attn_joint_twice.", "attn_joint_mid.")] = state_dict.pop(tmp) + if 'norm_joint_twice.' in key: + tmp = copy.deepcopy(key) + state_dict[key.replace("norm_joint_twice.", "norm_joint_mid.")] = state_dict.pop(tmp) + + model._convert_deprecated_attention_blocks(state_dict) + + conv_in_weight = state_dict['conv_in.weight'] + conv_out_weight = state_dict['conv_out.weight'] + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=True, + ) + if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_in.weight.data[:,:4] = conv_in_weight + + # whether to place all zero to new layers? + if zero_init_conv_in: + model.conv_in.weight.data[:,4:] = 0. + + if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_out.weight.data[:,:4] = conv_out_weight + if out_channels == 8: # copy for the last 4 channels + model.conv_out.weight.data[:, 4:] = conv_out_weight + + if zero_init_camera_projection: + for p in model.class_embedding.parameters(): + torch.nn.init.zeros_(p) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model_2d( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding 
`ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + diff --git a/mvdiffusion/pipelines/pipeline_mvdiffusion_image.py b/mvdiffusion/pipelines/pipeline_mvdiffusion_image.py new file mode 100644 index 0000000000000000000000000000000000000000..3414f9769336dd2564ba427a54c975b900cf423e --- /dev/null +++ b/mvdiffusion/pipelines/pipeline_mvdiffusion_image.py @@ -0,0 +1,509 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
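+#
+# Rough usage sketch (illustrative; the checkpoint path, `cond_images` (a list of PIL views),
+# and the `vae`, `image_encoder`, `scheduler`, `feature_extractor` objects are placeholders;
+# the multi-view UNet class is the one defined in this repository, assumed here to be
+# `UNetMV2DConditionModel`):
+#
+#   unet = UNetMV2DConditionModel.from_pretrained_2d(
+#       "path/to/stable-diffusion-image-variations", subfolder="unet",
+#       camera_embedding_type="e_de_da_sincos", num_views=6, sample_size=32,
+#       projection_class_embeddings_input_dim=10,
+#   )
+#   pipe = MVDiffusionImagePipeline(
+#       vae=vae, image_encoder=image_encoder, unet=unet, scheduler=scheduler,
+#       safety_checker=None, feature_extractor=feature_extractor,
+#       requires_safety_checker=False, num_views=6,
+#   ).to("cuda")
+#   views = pipe(image=cond_images, guidance_scale=3.0, output_type="pil").images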
+ +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import PIL +import torch +import torchvision.transforms.functional as TF +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate, logging, randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from einops import rearrange, repeat + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MVDiffusionImagePipeline(DiffusionPipeline): + r""" + Pipeline to generate image variations from an input image using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + # TODO: feature_extractor is required to encode images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + camera_embedding_type: str = 'e_de_da_sincos', + num_views: int = 6 + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.camera_embedding_type: str = camera_embedding_type + self.num_views: int = num_views + + self.camera_embedding = torch.tensor( + [[ 0.0000, 0.0000, 0.0000, 1.0000, 0.0000], + [ 0.0000, -0.2362, 0.8125, 1.0000, 0.0000], + [ 0.0000, -0.1686, 1.6934, 1.0000, 0.0000], + [ 0.0000, 0.5220, 3.1406, 1.0000, 0.0000], + [ 0.0000, 0.6904, 4.8359, 1.0000, 0.0000], + [ 0.0000, 0.3733, 5.5859, 1.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [ 0.0000, -0.2362, 0.8125, 0.0000, 1.0000], + [ 0.0000, -0.1686, 1.6934, 0.0000, 1.0000], + [ 0.0000, 0.5220, 3.1406, 0.0000, 1.0000], + [ 0.0000, 0.6904, 4.8359, 0.0000, 1.0000], + [ 0.0000, 0.3733, 5.5859, 0.0000, 1.0000]], dtype=torch.float16) + + def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + image_pt = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values + image_pt = image_pt.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image_pt).image_embeds + image_embeddings = 
image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + # Note: repeat differently from official pipelines + # B1B2B3B4 -> B1B2B3B4B1B2B3B4 + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device).to(dtype) + image_pt = image_pt * 2.0 - 1.0 + image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor + # Note: repeat differently from official pipelines + # B1B2B3B4 -> B1B2B3B4B1B2B3B4 + image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1) + + if do_classifier_free_guidance: + image_latents = torch.cat([torch.zeros_like(image_latents), image_latents]) + + return image_embeddings, image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_camera_embedding(self, camera_embedding: Union[float, torch.Tensor], do_classifier_free_guidance, num_images_per_prompt=1): + # (B, 3) + camera_embedding = camera_embedding.to(dtype=self.unet.dtype, device=self.unet.device) + + if self.camera_embedding_type == 'e_de_da_sincos': + # (B, 6) + camera_embedding = torch.cat([ + torch.sin(camera_embedding), + torch.cos(camera_embedding) + ], dim=-1) + assert self.unet.config.class_embed_type == 'projection' + assert self.unet.config.projection_class_embeddings_input_dim == 6 or self.unet.config.projection_class_embeddings_input_dim == 10 + else: + raise NotImplementedError + + # Note: repeat differently from official pipelines + # B1B2B3B4 -> B1B2B3B4B1B2B3B4 + camera_embedding = camera_embedding.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + camera_embedding = torch.cat([ + camera_embedding, + camera_embedding + ], dim=0) + + return camera_embedding + + @torch.no_grad() + def __call__( + self, + image: Union[List[PIL.Image.Image], torch.FloatTensor], + # elevation_cond: torch.FloatTensor, + # elevation: torch.FloatTensor, + # azimuth: torch.FloatTensor, + camera_embedding: Optional[torch.FloatTensor]=None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + normal_cond: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. 
Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + + Examples: + + ```py + from diffusers import StableDiffusionImageVariationPipeline + from PIL import Image + from io import BytesIO + import requests + + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + "lambdalabs/sd-image-variations-diffusers", revision="v2.0" + ) + pipe = pipe.to("cuda") + + url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200" + + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert("RGB") + + out = pipe(image, num_images_per_prompt=3, guidance_scale=15) + out["images"][0].save("result.jpg") + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + + # 2. Define call parameters + if isinstance(image, list): + batch_size = len(image) + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + assert batch_size >= self.num_views and batch_size % self.num_views == 0 + elif isinstance(image, PIL.Image.Image): + image = [image]*self.num_views*2 + batch_size = self.num_views*2 + + device = self._execution_device + dtype = self.vae.dtype + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale != 1.0 + + # 3. 
Encode input image + if isinstance(image, list): + image_pil = image + elif isinstance(image, torch.Tensor): + image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])] + image_embeddings, image_latents = self._encode_image(image_pil, device, num_images_per_prompt, do_classifier_free_guidance) + + if normal_cond is not None: + if isinstance(normal_cond, list): + normal_cond_pil = normal_cond + elif isinstance(normal_cond, torch.Tensor): + normal_cond_pil = [TF.to_pil_image(normal_cond[i]) for i in range(normal_cond.shape[0])] + _, image_latents = self._encode_image(normal_cond_pil, device, num_images_per_prompt, do_classifier_free_guidance) + + + # assert len(elevation_cond) == batch_size and len(elevation) == batch_size and len(azimuth) == batch_size + # camera_embeddings = self.prepare_camera_condition(elevation_cond, elevation, azimuth, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt) + + if camera_embedding is not None: + assert len(camera_embedding) == batch_size + else: + camera_embedding = self.camera_embedding.to(dtype) + camera_embedding = repeat(camera_embedding, "Nv Nce -> (B Nv) Nce", B=batch_size//len(camera_embedding)) + camera_embeddings = self.prepare_camera_embedding(camera_embedding, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.out_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([ + latent_model_input, image_latents + ], dim=1) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, class_labels=camera_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + if num_channels_latents == 8: + latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..221bbcf0c03d8d88dea81e3e904d491626162e73 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,31 @@ +--extra-index-url https://download.pytorch.org/whl/cu117 +torch==1.13.1 +torchvision==0.14.1 +diffusers[torch]==0.19.3 +xformers==0.0.16 +transformers>=4.25.1 +bitsandbytes==0.35.4 +decord==0.6.0 +pytorch-lightning<2 +omegaconf==2.2.3 +nerfacc==0.3.3 +trimesh==3.9.8 +pyhocon==0.3.57 +icecream==2.1.0 +PyMCubes==0.1.2 +accelerate +modelcards +einops +ftfy +piq +matplotlib +opencv-python +imageio +imageio-ffmpeg +scipy +pyransac3d +torch_efficient_distloss +tensorboard +rembg +segment_anything +gradio==3.50.2 diff --git a/run_test.sh b/run_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..167046f16f70188969b67d649239cff208ceb5bf --- /dev/null +++ b/run_test.sh @@ -0,0 +1 @@ +accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py --config configs/mvdiffusion-joint-ortho-6views.yaml \ No newline at end of file diff --git a/test_mvdiffusion_seq.py b/test_mvdiffusion_seq.py new file mode 100644 index 0000000000000000000000000000000000000000..71c1d75a99cae09e91298b15a4133b2ca2befb4f --- /dev/null +++ b/test_mvdiffusion_seq.py @@ -0,0 +1,325 @@ +import argparse +import datetime +import logging +import inspect +import math +import os +from typing import Dict, Optional, Tuple, List +from omegaconf import OmegaConf +from PIL import Image 
+import cv2 +import numpy as np +from dataclasses import dataclass +from packaging import version +import shutil +from collections import defaultdict + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +from torchvision.utils import make_grid, save_image + +import transformers +import accelerate +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel + +from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset + +from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline + +from einops import rearrange +from rembg import remove +import pdb + +weight_dtype = torch.float16 + + +@dataclass +class TestConfig: + pretrained_model_name_or_path: str + pretrained_unet_path:str + revision: Optional[str] + validation_dataset: Dict + save_dir: str + seed: Optional[int] + validation_batch_size: int + dataloader_num_workers: int + + local_rank: int + + pipe_kwargs: Dict + pipe_validation_kwargs: Dict + unet_from_pretrained_kwargs: Dict + validation_guidance_scales: List[float] + validation_grid_nrow: int + camera_embedding_lr_mult: float + + num_views: int + camera_embedding_type: str + + pred_type: str # joint, or ablation + + enable_xformers_memory_efficient_attention: bool + + cond_on_normals: bool + cond_on_colors: bool + + +def log_validation(dataloader, pipeline, cfg: TestConfig, weight_dtype, name, save_dir): + + + pipeline.set_progress_bar_config(disable=True) + + if cfg.seed is None: + generator = None + else: + generator = torch.Generator(device=pipeline.device).manual_seed(cfg.seed) + + images_cond, images_pred = [], defaultdict(list) + for i, batch in tqdm(enumerate(dataloader)): + # (B, Nv, 3, H, W) + imgs_in = batch['imgs_in'] + alphas = batch['alphas'] + # (B, Nv, Nce) + camera_embeddings = batch['camera_embeddings'] + filename = batch['filename'] + + bsz, num_views = imgs_in.shape[0], imgs_in.shape[1] + # (B*Nv, 3, H, W) + imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W") + alphas = rearrange(alphas, "B Nv C H W -> (B Nv) C H W") + # (B*Nv, Nce) + camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce") + + images_cond.append(imgs_in) + + with torch.autocast("cuda"): + # B*Nv images + for guidance_scale in cfg.validation_guidance_scales: + out = pipeline( + imgs_in, camera_embeddings, generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs + ).images + images_pred[f"{name}-sample_cfg{guidance_scale:.1f}"].append(out) + cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}") + + # pdb.set_trace() + for i in range(bsz): + scene = os.path.basename(filename[i]) + print(scene) + scene_dir = os.path.join(cur_dir, scene) + outs_dir = 
os.path.join(scene_dir, "outs") + masked_outs_dir = os.path.join(scene_dir, "masked_outs") + os.makedirs(outs_dir, exist_ok=True) + os.makedirs(masked_outs_dir, exist_ok=True) + img_in = imgs_in[i*num_views] + alpha = alphas[i*num_views] + img_in = torch.cat([img_in, alpha], dim=0) + save_image(img_in, os.path.join(scene_dir, scene+".png")) + for j in range(num_views): + view = VIEWS[j] + idx = i*num_views + j + pred = out[idx] + + # pdb.set_trace() + out_filename = f"{cfg.pred_type}_000_{view}.png" + pred = save_image(pred, os.path.join(outs_dir, out_filename)) + + rm_pred = remove(pred) + + save_image_numpy(rm_pred, os.path.join(scene_dir, out_filename)) + torch.cuda.empty_cache() + + + +def save_image(tensor, fp): + ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + # pdb.set_trace() + im = Image.fromarray(ndarr) + im.save(fp) + return ndarr + +def save_image_numpy(ndarr, fp): + im = Image.fromarray(ndarr) + im.save(fp) + +def log_validation_joint(dataloader, pipeline, cfg: TestConfig, weight_dtype, name, save_dir): + + pipeline.set_progress_bar_config(disable=True) + + if cfg.seed is None: + generator = None + else: + generator = torch.Generator(device=pipeline.device).manual_seed(cfg.seed) + + images_cond, normals_pred, images_pred = [], defaultdict(list), defaultdict(list) + for i, batch in tqdm(enumerate(dataloader)): + # repeat (2B, Nv, 3, H, W) + imgs_in = torch.cat([batch['imgs_in']]*2, dim=0) + + filename = batch['filename'] + + # (2B, Nv, Nce) + camera_embeddings = torch.cat([batch['camera_embeddings']]*2, dim=0) + + task_embeddings = torch.cat([batch['normal_task_embeddings'], batch['color_task_embeddings']], dim=0) + + camera_embeddings = torch.cat([camera_embeddings, task_embeddings], dim=-1) + + # (B*Nv, 3, H, W) + imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W") + # (B*Nv, Nce) + camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce") + + images_cond.append(imgs_in) + num_views = len(VIEWS) + with torch.autocast("cuda"): + # B*Nv images + for guidance_scale in cfg.validation_guidance_scales: + out = pipeline( + imgs_in, camera_embeddings, generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs + ).images + + bsz = out.shape[0] // 2 + normals_pred = out[:bsz] + images_pred = out[bsz:] + + cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}") + + for i in range(bsz//num_views): + scene = filename[i] + scene_dir = os.path.join(cur_dir, scene) + normal_dir = os.path.join(scene_dir, "normals") + masked_colors_dir = os.path.join(scene_dir, "masked_colors") + os.makedirs(normal_dir, exist_ok=True) + os.makedirs(masked_colors_dir, exist_ok=True) + for j in range(num_views): + view = VIEWS[j] + idx = i*num_views + j + normal = normals_pred[idx] + color = images_pred[idx] + + normal_filename = f"normals_000_{view}.png" + rgb_filename = f"rgb_000_{view}.png" + normal = save_image(normal, os.path.join(normal_dir, normal_filename)) + color = save_image(color, os.path.join(scene_dir, rgb_filename)) + + rm_normal = remove(normal) + rm_color = remove(color) + + save_image_numpy(rm_normal, os.path.join(scene_dir, normal_filename)) + save_image_numpy(rm_color, os.path.join(masked_colors_dir, rgb_filename)) + + torch.cuda.empty_cache() + + +def load_wonder3d_pipeline(cfg): + + pipeline = MVDiffusionImagePipeline.from_pretrained( + cfg.pretrained_model_name_or_path, + torch_dtype=weight_dtype + ) + + 
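+    # Load the pretrained multi-view diffusion pipeline in half precision
+    # (weight_dtype = torch.float16, defined at the top of this file). Memory-efficient
+    # attention is enabled on the UNet below, and the pipeline is moved to the GPU
+    # only if CUDA is available.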
# pipeline.to('cuda:0') + pipeline.unet.enable_xformers_memory_efficient_attention() + + + if torch.cuda.is_available(): + pipeline.to('cuda:0') + # sys.main_lock = threading.Lock() + return pipeline + + +def main( + cfg: TestConfig +): + + # If passed along, set the training seed now. + if cfg.seed is not None: + set_seed(cfg.seed) + + pipeline = load_wonder3d_pipeline(cfg) + + if cfg.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + print( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + pipeline.unet.enable_xformers_memory_efficient_attention() + print("use xformers.") + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Get the dataset + validation_dataset = MVDiffusionDataset( + **cfg.validation_dataset + ) + + + # DataLoaders creation: + validation_dataloader = torch.utils.data.DataLoader( + validation_dataset, batch_size=cfg.validation_batch_size, shuffle=False, num_workers=cfg.dataloader_num_workers + ) + + + os.makedirs(cfg.save_dir, exist_ok=True) + + if cfg.pred_type == 'joint': + log_validation_joint( + validation_dataloader, + pipeline, + cfg, + weight_dtype, + 'validation', + cfg.save_dir + ) + else: + log_validation( + validation_dataloader, + pipeline, + cfg, + weight_dtype, + 'validation', + cfg.save_dir + ) + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True) + args, extras = parser.parse_known_args() + + from utils.misc import load_config + + # parse YAML config to OmegaConf + cfg = load_config(args.config, cli_args=extras) + print(cfg) + schema = OmegaConf.structured(TestConfig) + # cfg = OmegaConf.load(args.config) + cfg = OmegaConf.merge(schema, cfg) + + if cfg.num_views == 6: + VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] + elif cfg.num_views == 4: + VIEWS = ['front', 'right', 'back', 'left'] + main(cfg) diff --git a/triton-2.0.0-cp310-cp310-win_amd64.whl b/triton-2.0.0-cp310-cp310-win_amd64.whl new file mode 100644 index 0000000000000000000000000000000000000000..e92974761fd4ade443e89067796b3780d7d0463b --- /dev/null +++ b/triton-2.0.0-cp310-cp310-win_amd64.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91a6ec395022743269c942df7af01b210f642fb633d146a811be05a455adbae2 +size 12643861 diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..c16fafa2ab8e7b934be711c41aed6e12001444fd --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,54 @@ +import os +from omegaconf import OmegaConf +from packaging import version + + +# ============ Register OmegaConf Recolvers ============= # +OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n)) +OmegaConf.register_new_resolver('add', lambda a, b: a + b) +OmegaConf.register_new_resolver('sub', lambda a, b: a - b) +OmegaConf.register_new_resolver('mul', lambda a, b: a * b) +OmegaConf.register_new_resolver('div', lambda a, b: a / b) +OmegaConf.register_new_resolver('idiv', lambda a, b: a // b) +OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p)) +# ======================================================= # + + +def 
prompt(question): + inp = input(f"{question} (y/n)").lower().strip() + if inp and inp == 'y': + return True + if inp and inp == 'n': + return False + return prompt(question) + + +def load_config(*yaml_files, cli_args=[]): + yaml_confs = [OmegaConf.load(f) for f in yaml_files] + cli_conf = OmegaConf.from_cli(cli_args) + conf = OmegaConf.merge(*yaml_confs, cli_conf) + OmegaConf.resolve(conf) + return conf + + +def config_to_primitive(config, resolve=True): + return OmegaConf.to_container(config, resolve=resolve) + + +def dump_config(path, config): + with open(path, 'w') as fp: + OmegaConf.save(config=config, f=fp) + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + + +def parse_version(ver): + return version.parse(ver)
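For quick reference, here is a minimal, illustrative sketch of driving the exported `MVDiffusionImagePipeline` directly in Python instead of through `run_test.sh`/`accelerate`. It is a sketch under assumptions, not part of the repository: the checkpoint directory `./ckpts/wonder3d`, the input file `input.png`, and the guidance scale are placeholders, and it assumes the checkpoint was saved with the pipeline's default per-view `camera_embedding`, so that passing `camera_embedding=None` falls back to it as in the `__call__` implementation above.

```python
# Illustrative only: paths, resolution, and guidance scale below are assumptions.
import torch
from PIL import Image

from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline

# Hypothetical local checkpoint directory saved with this pipeline class.
pipe = MVDiffusionImagePipeline.from_pretrained(
    "./ckpts/wonder3d",
    torch_dtype=torch.float16,
)
# Optional, mirrors load_wonder3d_pipeline() above (requires the xformers package).
pipe.unet.enable_xformers_memory_efficient_attention()
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

cond = Image.open("input.png").convert("RGB")

# A single PIL image is replicated internally to num_views * 2 conditions
# (normal-domain and color-domain); with camera_embedding=None the pipeline
# falls back to its stored per-view embeddings.
out = pipe(
    cond,
    num_inference_steps=50,
    guidance_scale=3.0,  # classifier-free guidance is applied whenever guidance_scale != 1.0
    output_type="pil",
)

# out.images holds the per-view predictions (normal maps and color images for a
# joint checkpoint, cf. log_validation_joint() above); save them all for inspection.
for i, img in enumerate(out.images):
    img.save(f"view_{i:02d}.png")
```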