Add post-training folder
190
post-training/VLM-R1/.gitignore
vendored
Normal file
@@ -0,0 +1,190 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
output/
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Temp folders
|
||||
data/
|
||||
wandb/
|
||||
scripts/
|
||||
checkpoints/
|
||||
.vscode/
|
||||
.cursor/
|
||||
|
||||
rec_jsons_processed/
|
||||
|
||||
test/
|
||||
images_eval/
|
||||
*.zip
|
||||
*.txt
|
||||
runs/
|
||||
temps/
|
||||
src/open-r1-multimodal/src/open_r1/grpo_jsonl_ocr.py
|
||||
128
post-training/VLM-R1/CODE_OF_CONDUCT.md
Normal file
@@ -0,0 +1,128 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
https://x.com/OmAI_lab.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.0, available at
|
||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||
enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
https://www.contributor-covenant.org/faq. Translations are available at
|
||||
https://www.contributor-covenant.org/translations.
|
||||
42
post-training/VLM-R1/Dockerfile
Normal file
@@ -0,0 +1,42 @@
|
||||
# Use the specified base image
|
||||
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
git \
|
||||
wget \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install flash-attn using pre-built wheel
|
||||
RUN pip install --no-cache-dir \
|
||||
"https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl" || \
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# Install additional required packages
|
||||
RUN pip install \
|
||||
wandb==0.18.3 \
|
||||
tensorboardx \
|
||||
qwen_vl_utils \
|
||||
torchvision \
|
||||
git+https://github.com/huggingface/transformers.git
|
||||
|
||||
# Copy local open-r1-multimodal repository
|
||||
COPY ./src/open-r1-multimodal /workspace/src/open-r1-multimodal
|
||||
|
||||
# Install open_r1
|
||||
WORKDIR /workspace/src/open-r1-multimodal
|
||||
RUN pip install -e ".[dev]"
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install vllm
|
||||
RUN pip install vllm==0.7.2
|
||||
|
||||
# Set environment variables for better Python output
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# Default command
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
201
post-training/VLM-R1/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
245
post-training/VLM-R1/README.md
Normal file
@@ -0,0 +1,245 @@
|
||||
# VLM-R1: A stable and generalizable R1-style Large Vision-Language Model
|
||||
|
||||
<font size=4><div align='center' > [[🤗 REC Demo](https://huggingface.co/spaces/omlab/VLM-R1-Referral-Expression)] [[🤗 OVD Demo](https://huggingface.co/spaces/omlab/VLM-R1-OVD)] [[🤗 REC Data](https://huggingface.co/datasets/omlab/VLM-R1)] [[🤗 Checkpoints](https://huggingface.co/collections/omlab/vlm-r1-models-67b7352db15c19d57157c348)] </div></font>
|
||||
|
||||
<font size=4><div align='center'>[[📄 Tech Report](https://arxiv.org/abs/2504.07615)] [[📝 Blog](https://om-ai-lab.github.io/index.html)]</div></font>
|
||||
|
||||
<div align="center">
|
||||
<img src="./assets/performance4.png" width="900"/>
|
||||
<div>
|
||||
<font size=4>
|
||||
<p>🎉 <b>Our VLM-R1 Math model reaches the top of the Open-Compass Math Leaderboard (under 4B parameters) and OVD model achieves the state-of-the-art performance on OVDEval.</b></p>
|
||||
</font>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Since the introduction of [Deepseek-R1](https://github.com/deepseek-ai/DeepSeek-R1), numerous works have emerged focusing on reproducing and improving upon it. In this project, we propose VLM-R1, a stable and generalizable R1-style Large Vision-Language Model.
|
||||
|
||||
Specifically, for the task of Referring Expression Comprehension (REC), we trained [Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL) using both R1 and SFT approaches. The results reveal that, on the in-domain test data, the performance of the SFT model shows little change compared to that of the R1 model base model when the number of training steps is relatively small (100–600 steps), while the R1 model shows a steady improvement (as shown at the left of the figure below). More importantly, on the out-of-domain test data, the SFT model’s performance deteriorates slightly as the number of steps increases. Nevertheless, the RL model generalizes its reasoning ability to the out-of-domain data (as shown at the right of the figure below).
|
||||
|
||||

|
||||
\* *We found previous REC SFT exps used a mismatch pixel config. Therefore, we re-run the study with the correct config on a more complex out-of-domain data. See our [findings](https://om-ai-lab.github.io/2025_03_24.html) for details.*
|
||||
|
||||
## 🚀 Features
|
||||
|
||||
This repository supports:
|
||||
|
||||
- **`Full Fine-tuning for GRPO`**: see [run_grpo_rec.sh](run_scripts/run_grpo_rec.sh)
|
||||
- **`Freeze Vision Modules`**: set `freeze_vision_modules` as `true` in the script.
|
||||
- **`LoRA Fine-tuning for GRPO`**: see [run_grpo_rec_lora.sh](run_scripts/run_grpo_rec_lora.sh)
|
||||
- **`Multi-node Training`**: see [multinode_training_demo.sh](run_scripts/multinode_training_demo.sh)
|
||||
- **`Multi-image Input Training`**: see [run_grpo_gui.sh](run_scripts/run_grpo_gui.sh)
|
||||
- **`For your own data`**: see [here](#for-your-own-data)
|
||||
- **`Various VLMs`**: see [How to add a new model](assets/add_new_model.md), now we support QwenVL and InternVL
|
||||
|
||||
## 🗞️ Update
|
||||
|
||||
- **`2025-04-16`**: Thanks @MoonHoplite for the solution, now zero2 training is supported.
|
||||
- **`2025-04-16`**: We update the codebase. Currently, we incorporate REC conducting into the [`grpo_jsonl.py`](src/open-r1-multimodal/src/open_r1/grpo_jsonl.py) for the unify implementation. Moreover, we add a parameter `is_reward_customized_from_vlm_module` to support customized reward function from the VLM module. If this is set to `true`, the reward function will be implemented in the [QwenVL2Module](src/open-r1-multimodal/src/open_r1/vlm_modules/qwen_module.py) or [InternVLModule](src/open-r1-multimodal/src/open_r1/vlm_modules/internvl_module.py) (depends on the model you use). Besides, the current training log could output more detailed information.
|
||||
- **`2025-04-11`**: 🔥🔥🔥 We release the [technical report](https://arxiv.org/abs/2504.07615) of VLM-R1, summarizing our main results and insights.
|
||||
- **`2025-04-03`**: We add the `odLength`, `weighted_sum`, and `cosine` reward used in OVD task, please refer our [blog post](https://om-ai-lab.github.io/2025_03_20.html) and [findings](https://om-ai-lab.github.io/2025_03_24.html) to the details of the reward usage and see [grpo_jsonl.py](src/open-r1-multimodal/src/open_r1/grpo_jsonl.py) for code implementation.
|
||||
- **`2025-03-24`**: 🔥 We release the [findings](https://om-ai-lab.github.io/2025_03_24.html) of VLM-R1-OVD.
|
||||
- **`2025-03-23`**: 🔥 We release the VLM-R1-OVD [model weights](https://huggingface.co/omlab/VLM-R1-Qwen2.5VL-3B-OVD-0321) and [demo](https://huggingface.co/spaces/omlab/VLM-R1-OVD), which shows the state-of-the-art performance on OVDEval. Welcome to use it.
|
||||
- **`2025-03-20`**: 🔥 We achieved SOTA results on [OVDEval](https://github.com/om-ai-lab/OVDEval) with our RL-based model, outperforming SFT baselines and specialized object detection models. Read our [blog post](https://om-ai-lab.github.io/2025_03_20.html) for details on how reinforcement learning enhances object detection performance.
|
||||
- **`2025-03-17`**: Our VLM-R1 Math model reaches the top of the [Open-Compass Math Leaderboard](https://rank.opencompass.org.cn/leaderboard-multimodal-reasoning/?m=REALTIME) (under 4B parameters). We have released the [checkpoint](https://huggingface.co/omlab/VLM-R1-Qwen2.5VL-3B-Math-0305).
|
||||
- **`2025-03-15`**: We support multi-image input data. Check the format of multi-image input [here](#for-your-own-data). We also provide an example of multi-image script [run_grpo_gui.sh](run_scripts/run_grpo_gui.sh), see [here](#for-your-own-data) for details.
|
||||
- **`2025-03-13`**: We support InternVL for GRPO. See [run_grpo_rec_internvl.sh](run_scripts/run_grpo_rec_internvl.sh) for details. The annotation json files used in InternVL are [here](https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/rec_jsons_internvl.zip). If you want to add your new model, please refer to [How to add a new model](assets/add_new_model.md).
|
||||
- **`2025-03-02`**: We support LoRA Fine-tuning for GRPO. See [run_grpo_rec_lora.sh](run_scripts/run_grpo_rec_lora.sh) for details.
|
||||
- **`2025-02-27`**: We support the `number of iterations per batch` and `epsilon value for clipping` in the original GRPO algorithm with args: `--num_iterations` and `--epsilon`.
|
||||
- **`2025-02-25`**: We support multi-node training for GRPO. See [multinode_training_demo.sh](run_scripts/multinode_training_demo.sh) for details.
|
||||
- **`2025-02-21`**: We release the [checkpoint](https://huggingface.co/omlab/Qwen2.5VL-3B-VLM-R1-REC-500steps) of the VLM-R1 REC model.
|
||||
- **`2025-02-20`**: We release the script for [general data loading](#for-your-own-data).
|
||||
- **`2025-02-19`**: We incorporate an explanation of the [SFT](#sft) method.
|
||||
- **`2025-02-17`**: We release the VLM-R1 REC [Demo](https://huggingface.co/spaces/omlab/VLM-R1-Referral-Expression) on Hugging Face Spaces.
|
||||
- **`2025-02-15`**: We release the VLM-R1 repository and [GRPO](#grpo) training script.
|
||||
|
||||
## 🤖 Models
|
||||
|
||||
- **[`OVD`](https://huggingface.co/omlab/VLM-R1-Qwen2.5VL-3B-OVD-0321)**: Trained with VLM-R1, our Open-Vocabulary Detection (OVD) model achieves the state-of-the-art performance on OVDEval.
|
||||
- **[`Math`](https://huggingface.co/omlab/VLM-R1-Qwen2.5VL-3B-Math-0305)**: Through VLM-R1 training, our math model focuses on multimodal reasoning tasks and has achieved Top1 on the OpenCompass Multi-modal Reasoning Leaderboard among models < 4B.
|
||||
- **[`REC`](https://huggingface.co/omlab/Qwen2.5VL-3B-VLM-R1-REC-500steps)**: Trained with VLM-R1, our Referring Expression Comprehension (REC) model showcases the superior performance on out-of-domain data and a series of reasoning-grounding tasks.
|
||||
|
||||
| Version | Base VLM | Checkpoint | Task Type |
|
||||
| -------------------------------- | ------------ | ---------------------------------------------------------------------------------------------------- | ------------------------- |
|
||||
| VLM-R1-Qwen2.5VL-3B-OVD-0321 | Qwen2.5VL-3B | [omlab/VLM-R1-Qwen2.5VL-3B-OVD-0321](https://huggingface.co/omlab/VLM-R1-Qwen2.5VL-3B-OVD-0321) | Open-Vocabulary Detection |
|
||||
| VLM-R1-Qwen2.5VL-3B-Math-0305 | Qwen2.5VL-3B | [omlab/VLM-R1-Qwen2.5VL-3B-Math-0305](https://huggingface.co/omlab/VLM-R1-Qwen2.5VL-3B-Math-0305) | Multi-Modal Math |
|
||||
| VLM-R1-Qwen2.5VL-3B-REC-500steps | Qwen2.5VL-3B | [omlab/Qwen2.5VL-3B-VLM-R1-REC-500steps](https://huggingface.co/omlab/Qwen2.5VL-3B-VLM-R1-REC-500steps) | REC/Reasoning-Grounding |
|
||||
|
||||
## 🎯 ToDo
|
||||
|
||||
- [X] Implement multi-node training.
|
||||
- [X] Implement LoRA Fine-tuning.
|
||||
- [X] Support more Multimodal LLMs.
|
||||
- [X] Support multi-image input.
|
||||
- [X] Release the VLM-R1 Math model.
|
||||
- [X] Release the blog of VLM-R1.
|
||||
- [X] Release the VLM-R1-OVD model.
|
||||
- [X] Release the technical report of VLM-R1.
|
||||
- [ ] Study cross task generalization.
|
||||
- [ ] Enhance VLM for other tasks [welcome issue].
|
||||
|
||||
## 🛠️ Setup
|
||||
|
||||
```bash
|
||||
conda create -n vlm-r1 python=3.10
|
||||
conda activate vlm-r1
|
||||
bash setup.sh
|
||||
```
|
||||
|
||||
## 💪🏻 Training
|
||||
|
||||
### Referring Expression Comprehension (REC)
|
||||
|
||||
#### 📚 GRPO
|
||||
|
||||
1. Download the [COCO Train2014 image](https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/train2014.zip) and unzip it, and we refer to the image dir as `<your_image_root>`.
|
||||
2. Download the [RefCOCO/+/g and LISA-Grounding Annotation files](https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/rec_jsons_processed.zip) and unzip it (LISA-Grounding is used for out-of-domain evaluation).
|
||||
3. Change the `data_paths` and `image_folders` in the [run_scripts/run_grpo_rec.sh](run_scripts/run_grpo_rec.sh) file.
|
||||
|
||||
```bash
|
||||
# These jsonl files are included in the annotation files at step 2.
|
||||
# Note: please use jsonl files instead of json files.
|
||||
data_paths="path/to/refcoco_train.jsonl:path/to/refcocop_train.jsonl:path/to/refcocog_train.jsonl"
|
||||
image_folders="path/to/coco:path/to/coco:path/to/coco"
|
||||
```
|
||||
|
||||
4. ``bash run_scripts/run_grpo_rec.sh``
|
||||
|
||||
> [!NOTE]
|
||||
> If you encounter 'CUDA out of memory' error, you can try to reduce the `per_device_train_batch_size`.
|
||||
|
||||
<div align="center">
|
||||
<img src="./assets/iou.jpg" width="750"/>
|
||||
</div>
|
||||
<!--  -->
|
||||
|
||||
#### 📚 Multi-Node GRPO
|
||||
|
||||
For multi-node training, please refers to [multinode_training_demo.sh](src/open-r1-multimodal/multinode_training_demo.sh).
|
||||
|
||||
#### 📚 SFT
|
||||
|
||||
We use [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) to train the SFT model.
|
||||
|
||||
1. Clone the [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) repository and install the dependencies.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e ".[torch,metrics]"
|
||||
```
|
||||
|
||||
2. Download the dataset_info.json, mllm_rec_json.json, and qwen2_5_vl_full_sft.yaml we provided [here](https://huggingface.co/datasets/omlab/VLM-R1/tree/main/sft_related). Put the json files in the `LLaMA-Factory/data` directory and the yaml file in the `LLaMA-Factory/examples/train_full` directory.
|
||||
3. Run the following command to train the SFT model.
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_full/qwen2_5_vl_full_sft.yaml
|
||||
```
|
||||
|
||||
### For your own data
|
||||
|
||||
<div style="text-align: justify;">
|
||||
|
||||
We support data loading the jsonl data of this format in [`src/open-r1-multimodal/src/open_r1/grpo_jsonl.py`](src/open-r1-multimodal/src/open_r1/grpo_jsonl.py). Please note that you may need to use different reward functions for your specialized tasks. Welcome to PR to add your own reward functions or share any other interesting findings!
|
||||
|
||||
</div>
|
||||
|
||||
The jsonl has the format as follows:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": 1,
|
||||
"image": "Clevr_CoGenT_TrainA_R1/data/images/CLEVR_trainA_000001_16885.png",
|
||||
"conversations": [
|
||||
{"from": "human", "value": "<image>What number of purple metallic balls are there?"},
|
||||
{"from": "gpt", "value": "0"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
If you want to use multi-image input, you can use the following format:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": 1,
|
||||
"image": ["Clevr_CoGenT_TrainA_R1/data/images/CLEVR_trainA_000001_16885.png", "Clevr_CoGenT_TrainA_R1/data/images/CLEVR_trainA_000001_16886.png"],
|
||||
"conversations": [
|
||||
{"from": "human", "value": "<image><image>What number of purple metallic balls in total within the two images?"},
|
||||
{"from": "gpt", "value": "3"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The image path in the jsonl file should be relative to the image folder specified in `--image_folders`. The absolute path of the input image is constructed as `os.path.join(image_folder, data['image'])`. For example:
|
||||
|
||||
- If your jsonl has `"image": "folder1/image1.jpg"`
|
||||
- And you specify `--image_folders "/path/to/images/"`
|
||||
- The full image path will be `/path/to/images/folder1/image1.jpg`
|
||||
|
||||
Multiple data files and image folders can be specified using ":" as a separator:
|
||||
|
||||
```bash
|
||||
--data_file_paths /path/to/data1.jsonl:/path/to/data2.jsonl \
|
||||
--image_folders /path/to/images1/:/path/to/images2/
|
||||
```
|
||||
|
||||
The script can be run like this:
|
||||
|
||||
```bash
|
||||
# You could refer to the run_grpo_rec.sh for the example
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12345" \
|
||||
src/open_r1/grpo_jsonl.py \
|
||||
--output_dir output/$RUN_NAME \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--deepspeed ${REPO_HOME}/src/open-r1-multimodal/local_scripts/zero3.json \
|
||||
--data_file_paths /path/to/your/data.jsonl \ # can be multiple, separated by ":"
|
||||
--image_folders /path/to/your/image/folder \ # can be multiple, separated by ":"
|
||||
...
|
||||
```
|
||||
|
||||
<div style="text-align: justify;">
|
||||
|
||||
### Multi-image Input
|
||||
We provide an example of multi-image script [run_grpo_gui.sh](src/open-r1-multimodal/run_scripts/run_grpo_gui.sh). This task requires the model to analyze two GUI screenshots, taken before and after a user action, to determine if any UI interaction defects are present, which is from [GUI-Testing-Arena](https://huggingface.co/datasets/songjah/GTArena-UI-Defects). Download the [image](https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/gui_multi-image.zip) and unzip it into the `/path/to/images/`. Then modify the `image_folders` parameter in the script and run it.
|
||||
|
||||
```bash
|
||||
bash run_scripts/run_grpo_gui.sh
|
||||
```
|
||||
|
||||
</div>
|
||||
|
||||
## 📊 Evaluation
|
||||
|
||||

|
||||
|
||||
1. Download the provided [LISA-Grounding images](https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/lisa-test.zip).
|
||||
|
||||
```bash
|
||||
cd ./src/eval
|
||||
|
||||
# Remember to change the model path, image root, and annotation path in the script
|
||||
torchrun --nproc_per_node=X test_rec_r1.py # for GRPO. 'X' is the number of GPUs you have.
|
||||
torchrun --nproc_per_node=X test_rec_baseline.py # for SFT.
|
||||
```
|
||||
|
||||
## 🤝 Acknowledgements
|
||||
|
||||
We would like to express our sincere gratitude to [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1), [Open-R1](https://github.com/huggingface/open-r1), [QwenVL](https://github.com/QwenLM/Qwen2.5-VL), [Open-R1-Multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal), [R1-V](https://github.com/Deep-Agent/R1-V), [RefCOCO](https://github.com/lichengunc/refer), [RefGTA](https://github.com/mikittt/easy-to-understand-REG/tree/master/pyutils/refer2), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), [OVDEval](https://github.com/om-ai-lab/OVDEval), [GUI-Testing-Arena](https://huggingface.co/datasets/songjah/GTArena-UI-Defects), and [LISA](https://github.com/dvlab-research/LISA) for providing open-source resources that contributed to the development of this project.
|
||||
|
||||
## ⭐️ Citation
|
||||
|
||||
If you find this project useful, welcome to cite us.
|
||||
|
||||
```bib
|
||||
@article{shen2025vlm,
|
||||
title={Vlm-r1: A stable and generalizable r1-style large vision-language model},
|
||||
author={Shen, Haozhan and Liu, Peng and Li, Jingcheng and Fang, Chunxin and Ma, Yibo and Liao, Jiajia and Shen, Qiaoli and Zhang, Zilun and Zhao, Kangjia and Zhang, Qianqian and Xu, Ruochen and Zhao, Tiancheng },
|
||||
journal={arXiv preprint arXiv:2504.07615},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
50
post-training/VLM-R1/assets/add_new_model.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# How to add a new model
|
||||
|
||||
## VLM Module
|
||||
|
||||
<div align=center>
|
||||
<img width="70%" src="module.png"/>
|
||||
</div>
|
||||
|
||||
To enhance scalability and ease of integration for new models, we create the [VLM Module class](../src/open-r1-multimodal/src/open_r1/vlm_modules/vlm_module.py). As shown in the figure above, The current **GRPO Trainer** primarily handles abstract operations, such as "placing the question into the chat template" and "converting the image and prompt into input_ids". The actual implementation is delegated to the **VLM Module**, while the **GRPO Trainer** is responsible solely for calling the exposed function interfaces of the **VLM Module**.
|
||||
|
||||
## The implemented function of VLM Module
|
||||
|
||||
To add a new model, you need to implement the following functions in the **VLM Module**:
|
||||
### 1. get_vlm_key
|
||||
Return the identifier of the model, such as "internvl", "qwen".
|
||||
|
||||
### 2. get_model_class
|
||||
Return the model class of the model that is used to initialize in the `GRPO Trainer`. For "qwen", the model class is `Qwen2_5_VLForConditionalGeneration` or `Qwen2VLForConditionalGeneration`, and for "internvl", the model class is `InternVLChatModel`.
|
||||
|
||||
### 3. post_model_init
|
||||
This function is called after the model and processor are initialized. You can do some post-processing here. Taking "internvl" as an example, we need to record the `conv_template` and `num_image_token` for later use, and set the `img_context_token_id` for the model.
|
||||
|
||||
### 4. is_embeds_input
|
||||
Return whether the model accepts `input_embedding` as input while not `input_ids` when calling `generate` method.
|
||||
|
||||
### 5. get_processing_class
|
||||
Return the processing class of the model. For most models, `AutoProcessor` is typically used.
|
||||
|
||||
### 6. get_vision_modules_keywords
|
||||
Return the keywords of the vision modules of the model. This is used to freeze the vision modules in the `GRPO Trainer`.
|
||||
|
||||
### 7. get_custom_multimodal_keywords
|
||||
Besides `input_ids` and `attention_mask`, the model also accepts some distinct custom multimodal inputs for different VLMs when calling `forward` method, such as `pixel_values` and `image_thw` for "qwen", and `pixel_values` and `image_flags` for "internvl".
|
||||
|
||||
### 8. get_non_generate_params
|
||||
There may be some parameters in the custom multimodal inputs that are not used in the `generate` method, such as `image_flags` for "internvl". You need to return them in the `get_non_generate_params` function.
|
||||
|
||||
### 9. get_custom_processing_keywords
|
||||
Some models may have some specific parameters for the `processing_class`, such as `max_pixels` and `min_pixels` for "qwen", and `max_anyres_num` for "internvl". You need to return them in the `get_custom_processing_keywords` function.
|
||||
|
||||
### 10. prepare_prompt
|
||||
This function is used to place the prompt into the chat template. Different models may have different processing methods, so you need to implement this function according to the model.
|
||||
|
||||
### 11. prepare_model_inputs
|
||||
This function is used to process the image and prompt into the format that the model accepts. The returned value should be a `dict` with the following keys: `input_ids`, `attention_mask`, and the custom multimodal inputs.
|
||||
|
||||
#### You could refer to [qwen_module.py](../src/open-r1-multimodal/src/open_r1/vlm_modules/qwen_module.py) and [internvl_module.py](../src/open-r1-multimodal/src/open_r1/vlm_modules/internvl_module.py) for the example implementations of QwenVL and InternVL respectively.
|
||||
|
||||
|
||||
|
||||
BIN
post-training/VLM-R1/assets/data2.png
Normal file
|
After Width: | Height: | Size: 2.5 MiB |
BIN
post-training/VLM-R1/assets/iou.jpg
Normal file
|
After Width: | Height: | Size: 58 KiB |
BIN
post-training/VLM-R1/assets/math-leaderboard.jpg
Normal file
|
After Width: | Height: | Size: 735 KiB |
BIN
post-training/VLM-R1/assets/module.png
Normal file
|
After Width: | Height: | Size: 682 KiB |
BIN
post-training/VLM-R1/assets/performance3.png
Normal file
|
After Width: | Height: | Size: 473 KiB |
BIN
post-training/VLM-R1/assets/performance4.png
Normal file
|
After Width: | Height: | Size: 2.0 MiB |
BIN
post-training/VLM-R1/assets/wandb.jpg
Normal file
|
After Width: | Height: | Size: 247 KiB |
BIN
post-training/VLM-R1/placeholder.jpg
Normal file
|
After Width: | Height: | Size: 7.9 KiB |
109
post-training/VLM-R1/run_exp.sh
Executable file
@@ -0,0 +1,109 @@
|
||||
# 1. Export paths
|
||||
cd src/open-r1-multimodal
|
||||
export DEBUG_MODE="true"
|
||||
export PROJ_ROOT="$HOME_ROOT/code_mllm"
|
||||
RUN_NAME="Qwen2.5-VL-7B-GRPO-websight"
|
||||
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
||||
export PLACEHOLDER_PATH="$PROJ_ROOT/VLM-R1/placeholder.jpg"
|
||||
export CSS_PATH="$PROJ_ROOT/VLM-R1/tailwind.min.css"
|
||||
image_folder="$PROJ_ROOT/LLaMA-Factory/data"
|
||||
data_file_paths="$PROJ_ROOT/LLaMA-Factory/data/CodeMLLM/websight/train_rl.json"
|
||||
|
||||
# 2. Experiment parameters
|
||||
model_name="Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
gpu_num="8"
|
||||
bs_per_device=1
|
||||
num_generations=8 # assert (bs_per_device x gpu_num) % num_generations == 0
|
||||
resume="True"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# 1. Ensure PROJ_ROOT is set
|
||||
if [[ -z "${PROJ_ROOT:-}" ]]; then
|
||||
echo "ERROR: PROJ_ROOT is not defined." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
|
||||
# 3. Declare expected type for each
|
||||
declare -A expected=(
|
||||
[PLACEHOLDER_PATH]=file
|
||||
[CSS_PATH]=file
|
||||
[image_folder]=dir
|
||||
[data_file_paths]=file
|
||||
)
|
||||
|
||||
# 4. Test existence
|
||||
all_good=true
|
||||
for var in "${!expected[@]}"; do
|
||||
path="${!var}"
|
||||
type="${expected[$var]}"
|
||||
case "$type" in
|
||||
file)
|
||||
if [[ ! -f "$path" ]]; then
|
||||
echo "✗ File missing: $var → $path" >&2
|
||||
all_good=false
|
||||
else
|
||||
echo "✔ File exists: $var → $path"
|
||||
fi
|
||||
;;
|
||||
dir)
|
||||
if [[ ! -d "$path" ]]; then
|
||||
echo "✗ Directory missing: $var → $path" >&2
|
||||
all_good=false
|
||||
else
|
||||
echo "✔ Directory exists: $var → $path"
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "WARNING: Unknown type for $var: $type" >&2
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 5. Exit non-zero if any missing
|
||||
if ! $all_good; then
|
||||
echo "One or more paths were missing." >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
echo "All paths verified successfully."
|
||||
|
||||
|
||||
torchrun --nproc_per_node=$gpu_num \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12346" \
|
||||
src/open_r1/grpo_jsonl.py \
|
||||
--deepspeed local_scripts/zero3.json \
|
||||
--output_dir $PROJ_ROOT/VLM-R1/output/$RUN_NAME \
|
||||
--model_name_or_path $model_name \
|
||||
--dataset_name none \
|
||||
--image_folders $image_folder\
|
||||
--data_file_paths $data_file_paths \
|
||||
--freeze_vision_modules true \
|
||||
--max_pixels 1843200 \
|
||||
--max_prompt_length 4096 \
|
||||
--max_completion_length 2048 \
|
||||
--num_generations $num_generations \
|
||||
--per_device_train_batch_size $bs_per_device \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 1 \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--data_seed 42 \
|
||||
--report_to wandb \
|
||||
--gradient_checkpointing true \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--num_train_epochs 2 \
|
||||
--run_name $RUN_NAME \
|
||||
--save_steps 100 \
|
||||
--save_only_model true \
|
||||
--resume_from_checkpoint $resume \
|
||||
@@ -0,0 +1,21 @@
|
||||
output_dir: /path/to/output/runs/Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps
|
||||
model_name_or_path: /path/to/models/Qwen2.5-VL-3B-Instruct
|
||||
dataset_name: Idefics-ai2d
|
||||
data_file_paths: /path/to/data/ai2d.jsonl
|
||||
image_folders: /path/to/images
|
||||
max_prompt_length: 1024
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
logging_steps: 1
|
||||
bf16: true
|
||||
report_to: wandb
|
||||
gradient_checkpointing: false
|
||||
deepspeed: /path/to/config/zero3.json
|
||||
attn_implementation: flash_attention_2
|
||||
max_pixels: 401408
|
||||
max_steps: 500
|
||||
run_name: Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps-multinode
|
||||
save_steps: 100
|
||||
save_total_limit: 3
|
||||
save_only_model: true
|
||||
num_generations: 8
|
||||
145
post-training/VLM-R1/run_scripts/multinode_training_demo.sh
Normal file
@@ -0,0 +1,145 @@
|
||||
#!/bin/bash
|
||||
|
||||
RUN_NAME=multinode_training # assume there is a ${RUN_NAME}_args.yaml file in the current directory
|
||||
|
||||
declare -A node2ip_map
|
||||
node2ip_map=(
|
||||
["node1"]="192.168.1.101"
|
||||
["node2"]="192.168.1.102"
|
||||
["node3"]="192.168.1.103"
|
||||
["node4"]="192.168.1.104"
|
||||
)
|
||||
|
||||
# Default nodes if no arguments provided
|
||||
DEFAULT_NODES=("node1" "node2")
|
||||
|
||||
# Local codebase path in file system
|
||||
LOCAL_CODEBASE_PATH="/path/to/your/codebase"
|
||||
|
||||
# Use provided nodes or default nodes
|
||||
if [ "$#" -ge 1 ]; then
|
||||
NODES=("$@")
|
||||
else
|
||||
NODES=("${DEFAULT_NODES[@]}")
|
||||
echo "Using default nodes: ${NODES[*]}"
|
||||
fi
|
||||
|
||||
# Add this debug line
|
||||
echo "All nodes in order: ${NODES[@]}"
|
||||
|
||||
TOTAL_NODES=${#NODES[@]}
|
||||
MASTER_NODE=${NODES[0]}
|
||||
MASTER_PORT=12345
|
||||
|
||||
# Get project root directory (using the directory where this script is located)
|
||||
PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
echo "Project root directory: $PROJECT_ROOT"
|
||||
|
||||
# Get master node IP address
|
||||
echo "MASTER_NODE: $MASTER_NODE"
|
||||
MASTER_IP="${node2ip_map[$MASTER_NODE]}"
|
||||
echo "Master node IP: $MASTER_IP"
|
||||
|
||||
# Create log directory for each node
|
||||
LOG_DIR="path/to/your/log/dir"
|
||||
mkdir -p $LOG_DIR
|
||||
|
||||
# Generate docker-compose.yml
|
||||
echo "Generating docker-compose.yml..."
|
||||
cat > docker-compose.yml << EOL
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
trainer:
|
||||
image: your/training-image:tag
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
shm_size: '8gb'
|
||||
volumes:
|
||||
- /path/to/data:/data
|
||||
- $LOCAL_CODEBASE_PATH/src:/workspace/src
|
||||
environment:
|
||||
- MASTER_ADDR=\${MASTER_ADDR:-$MASTER_IP}
|
||||
- MASTER_PORT=\${MASTER_PORT:-12345}
|
||||
- NODE_RANK=\${NODE_RANK:-0}
|
||||
- WORLD_SIZE=\${WORLD_SIZE:-4}
|
||||
- DEBUG_MODE=true
|
||||
- LOG_PATH=${LOG_DIR}/debug_log.txt
|
||||
- WANDB_API_KEY=your_wandb_api_key # Optional: for logging with weights & biases
|
||||
- WANDB_PROJECT=your_project_name
|
||||
- WANDB_RUN_NAME=${RUN_NAME}-$(date +%Y-%m-%d-%H-%M-%S)
|
||||
- PYTHONPATH=/workspace/src
|
||||
network_mode: "host"
|
||||
command: /bin/bash
|
||||
working_dir: /workspace
|
||||
EOL
|
||||
|
||||
# Function to build training arguments from yaml
|
||||
build_train_args() {
|
||||
args=""
|
||||
while IFS=": " read -r key value; do
|
||||
[[ -z "$key" || "$key" =~ ^[[:space:]]*# ]] && continue
|
||||
value=$(echo "$value" | sed -e 's/^[[:space:]]*//' -e 's/[[:space:]]*$//' -e 's/^"//' -e 's/"$//')
|
||||
if [[ "$value" == "true" ]]; then
|
||||
args="$args --$key"
|
||||
elif [[ "$value" == "false" ]]; then
|
||||
continue
|
||||
else
|
||||
args="$args --$key $value"
|
||||
fi
|
||||
done < ${RUN_NAME}_args.yaml
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
# Get training arguments
|
||||
TRAIN_ARGS=$(build_train_args)
|
||||
echo "TRAIN_ARGS: $TRAIN_ARGS"
|
||||
|
||||
# Launch containers on each node
|
||||
NODE_RANK=0
|
||||
for host in "${NODES[@]}"; do
|
||||
LOG_FILE="$LOG_DIR/${host}_rank${NODE_RANK}.log"
|
||||
if [ "$host" = "$MASTER_NODE" ]; then
|
||||
echo "Launching on master $host with rank $NODE_RANK, logging to $LOG_FILE"
|
||||
ssh $host "cd $PROJECT_ROOT && \
|
||||
MASTER_ADDR=$MASTER_IP \
|
||||
NODE_RANK=$NODE_RANK \
|
||||
WORLD_SIZE=$TOTAL_NODES \
|
||||
sudo -E docker-compose -f docker-compose.yml run --rm trainer \
|
||||
torchrun --nproc_per_node=8 \
|
||||
--nnodes=$TOTAL_NODES \
|
||||
--node_rank=$NODE_RANK \
|
||||
--master_addr=$MASTER_IP \
|
||||
--master_port=$MASTER_PORT \
|
||||
src/train.py \
|
||||
$TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
|
||||
else
|
||||
echo "Launching on $host with rank $NODE_RANK, logging to $LOG_FILE"
|
||||
ssh $host "cd $PROJECT_ROOT && \
|
||||
MASTER_ADDR=$MASTER_IP \
|
||||
NODE_RANK=$NODE_RANK \
|
||||
WORLD_SIZE=$TOTAL_NODES \
|
||||
sudo -E docker-compose -f docker-compose.yml run --rm trainer \
|
||||
torchrun --nproc_per_node=8 \
|
||||
--nnodes=$TOTAL_NODES \
|
||||
--node_rank=$NODE_RANK \
|
||||
--master_addr=$MASTER_IP \
|
||||
--master_port=$MASTER_PORT \
|
||||
src/train.py \
|
||||
$TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
|
||||
fi
|
||||
|
||||
NODE_RANK=$((NODE_RANK + 1))
|
||||
done
|
||||
|
||||
echo "Jobs launched. To monitor the logs, you can:"
|
||||
echo "1. Use 'tail -f $LOG_DIR/*.log' to watch all logs"
|
||||
echo "2. Use 'tail -f $LOG_DIR/<node_name>_rank<N>.log' to watch a specific node"
|
||||
|
||||
# Wait for all background processes to complete
|
||||
wait
|
||||
58
post-training/VLM-R1/run_scripts/run_grpo_gui.sh
Normal file
@@ -0,0 +1,58 @@
|
||||
PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
|
||||
export REPO_HOME="${PROJECT_ROOT}" # TODO: change this to your own
|
||||
echo "REPO_HOME: $REPO_HOME"
|
||||
# on remote
|
||||
data_paths="${REPO_HOME}/src/open-r1-multimodal/data_jsonl/gui_multi-image.jsonl"
|
||||
image_folders="/data9/shz/project/vlm-r1/VLM-R1/images/gui_multi-image"
|
||||
model_path="/data9/shz/ckpt/Qwen2.5-VL-3B-Instruct"
|
||||
is_reward_customized_from_vlm_module=False
|
||||
reward_methods="all_match"
|
||||
echo "data_paths: $data_paths"
|
||||
echo "image_folders: $image_folders"
|
||||
|
||||
export EXP_NAME="GUI-multi-image" # TODO: change this to your own experiment name
|
||||
TASK_TYPE="gui"
|
||||
cd ${REPO_HOME}/src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
|
||||
# create the run directory and log file
|
||||
mkdir -p ${REPO_HOME}/runs/${EXP_NAME}/log
|
||||
export LOG_PATH="${REPO_HOME}/runs/${EXP_NAME}/log/debug_log.$(date +%Y-%m-%d-%H-%M-%S).txt"
|
||||
MAX_STEPS=1200 # TODO: change this to your own max steps
|
||||
|
||||
# export WANDB_DISABLED=true
|
||||
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12349" \
|
||||
src/open_r1/grpo_jsonl.py \
|
||||
--use_vllm False \
|
||||
--output_dir ${REPO_HOME}/checkpoints/rl/${EXP_NAME} \
|
||||
--resume_from_checkpoint True \
|
||||
--model_name_or_path $model_path \
|
||||
--data_file_paths $data_paths \
|
||||
--image_folders $image_folders \
|
||||
--is_reward_customized_from_vlm_module $is_reward_customized_from_vlm_module \
|
||||
--reward_method $reward_methods \
|
||||
--task_type $TASK_TYPE \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--gradient_checkpointing true \
|
||||
--logging_steps 1 \
|
||||
--num_train_epochs 2 \
|
||||
--max_steps $MAX_STEPS \
|
||||
--bf16 \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--run_name ${EXP_NAME} \
|
||||
--save_steps 400 \
|
||||
--num_generations 8 \
|
||||
--max_completion_length 2048 \
|
||||
--reward_funcs accuracy format \
|
||||
--beta 0.04 \
|
||||
--report_to wandb \
|
||||
--dataset-name not_used \
|
||||
--deepspeed ${REPO_HOME}/src/open-r1-multimodal/local_scripts/zero3.json \
|
||||
|
||||
echo "Training completed for ${EXP_NAME}"
|
||||
57
post-training/VLM-R1/run_scripts/run_grpo_rec.sh
Normal file
@@ -0,0 +1,57 @@
|
||||
PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
|
||||
export REPO_HOME="${PROJECT_ROOT}"
|
||||
echo "REPO_HOME: $REPO_HOME"
|
||||
# Change the data_paths and image_folders to your own data
|
||||
data_paths="/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcoco_train.jsonl:/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcocop_train.jsonl:/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcocog_train.jsonl"
|
||||
image_folders="/training/shz/dataset/coco:/training/shz/dataset/coco:/training/shz/dataset/coco"
|
||||
model_path="/training/models/Qwen2.5-VL-3B-Instruct"
|
||||
is_reward_customized_from_vlm_module=True
|
||||
echo "data_paths: $data_paths"
|
||||
echo "image_folders: $image_folders"
|
||||
|
||||
export EXP_NAME="Qwen2.5-VL-3B-Instruct-rec" # TODO: change this to your own experiment name
|
||||
TASK_TYPE="rec"
|
||||
cd ${REPO_HOME}/src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
|
||||
# create the run directory and log file
|
||||
mkdir -p ${REPO_HOME}/runs/${EXP_NAME}/log
|
||||
export LOG_PATH="${REPO_HOME}/runs/${EXP_NAME}/log/debug_log.$(date +%Y-%m-%d-%H-%M-%S).txt"
|
||||
# MAX_STEPS=1200 # TODO: change this to your own max steps
|
||||
|
||||
|
||||
# export WANDB_DISABLED=true
|
||||
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12349" \
|
||||
src/open_r1/grpo_jsonl.py \
|
||||
--use_vllm False \
|
||||
--output_dir ${REPO_HOME}/checkpoints/rl/${EXP_NAME} \
|
||||
--resume_from_checkpoint True \
|
||||
--model_name_or_path $model_path \
|
||||
--data_file_paths $data_paths \
|
||||
--image_folders $image_folders \
|
||||
--is_reward_customized_from_vlm_module $is_reward_customized_from_vlm_module \
|
||||
--task_type $TASK_TYPE \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--gradient_checkpointing true \
|
||||
--logging_steps 1 \
|
||||
--num_train_epochs 2 \
|
||||
--bf16 \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--run_name ${EXP_NAME} \
|
||||
--data_seed 42 \
|
||||
--save_steps 100 \
|
||||
--num_generations 8 \
|
||||
--max_completion_length 2048 \
|
||||
--reward_funcs accuracy format \
|
||||
--beta 0.04 \
|
||||
--report_to wandb \
|
||||
--dataset-name this_is_not_used \
|
||||
--deepspeed ${REPO_HOME}/src/open-r1-multimodal/local_scripts/zero3.json \
|
||||
|
||||
echo "Training completed for ${EXP_NAME}"
|
||||
58
post-training/VLM-R1/run_scripts/run_grpo_rec_internvl.sh
Normal file
@@ -0,0 +1,58 @@
|
||||
PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
|
||||
export REPO_HOME="${PROJECT_ROOT}"
|
||||
echo "REPO_HOME: $REPO_HOME"
|
||||
# on remote
|
||||
data_paths="/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcoco_train.jsonl:/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcocop_train.jsonl:/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcocog_train.jsonl"
|
||||
image_folders="/training/shz/dataset/coco:/training/shz/dataset/coco:/training/shz/dataset/coco"
|
||||
model_path="OpenGVLab/InternVL2_5-4B-MPO"
|
||||
is_reward_customized_from_vlm_module=True
|
||||
echo "data_paths: $data_paths"
|
||||
echo "image_folders: $image_folders"
|
||||
|
||||
export EXP_NAME="InternVL2_5-4B_MPO-rec" # TODO: change this to your own experiment name
|
||||
TASK_TYPE="rec"
|
||||
cd ${REPO_HOME}/src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
|
||||
# create the run directory and log file
|
||||
mkdir -p ${REPO_HOME}/runs/${EXP_NAME}/log
|
||||
export LOG_PATH="${REPO_HOME}/runs/${EXP_NAME}/log/debug_log.$(date +%Y-%m-%d-%H-%M-%S).txt"
|
||||
# MAX_STEPS=1200 # TODO: change this to your own max steps
|
||||
|
||||
|
||||
# export WANDB_DISABLED=true
|
||||
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12349" \
|
||||
src/open_r1/grpo_jsonl.py \
|
||||
--use_vllm False \
|
||||
--output_dir ${REPO_HOME}/checkpoints/rl/${EXP_NAME} \
|
||||
--resume_from_checkpoint True \
|
||||
--model_name_or_path $model_path \
|
||||
--data_file_paths $data_paths \
|
||||
--image_folders $image_folders \
|
||||
--is_reward_customized_from_vlm_module $is_reward_customized_from_vlm_module \
|
||||
--task_type $TASK_TYPE \
|
||||
--max_anyres_num 6 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--gradient_checkpointing true \
|
||||
--logging_steps 1 \
|
||||
--num_train_epochs 2 \
|
||||
--bf16 \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--run_name ${EXP_NAME} \
|
||||
--data_seed 42 \
|
||||
--save_steps 100 \
|
||||
--num_generations 8 \
|
||||
--max_completion_length 2048 \
|
||||
--reward_funcs accuracy format \
|
||||
--beta 0.04 \
|
||||
--report_to wandb \
|
||||
--dataset-name this_is_not_used \
|
||||
--deepspeed ${REPO_HOME}/src/open-r1-multimodal/local_scripts/zero3.json \
|
||||
|
||||
echo "Training completed for ${EXP_NAME}"
|
||||
64
post-training/VLM-R1/run_scripts/run_grpo_rec_lora.sh
Normal file
@@ -0,0 +1,64 @@
|
||||
PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
|
||||
export REPO_HOME="${PROJECT_ROOT}"
|
||||
echo "REPO_HOME: $REPO_HOME"
|
||||
# on remote
|
||||
data_paths="/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcoco_train.jsonl:/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcocop_train.jsonl:/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcocog_train.jsonl"
|
||||
image_folders="/training/shz/dataset/coco:/training/shz/dataset/coco:/training/shz/dataset/coco"
|
||||
model_path="/training/models/Qwen2.5-VL-3B-Instruct"
|
||||
is_reward_customized_from_vlm_module=True
|
||||
echo "data_paths: $data_paths"
|
||||
echo "image_folders: $image_folders"
|
||||
|
||||
export EXP_NAME="Qwen2.5-VL-3B-Instruct-rec-lora" # TODO: change this to your own experiment name
|
||||
TASK_TYPE="rec"
|
||||
cd ${REPO_HOME}/src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
|
||||
# create the run directory and log file
|
||||
mkdir -p ${REPO_HOME}/runs/${EXP_NAME}/log
|
||||
export LOG_PATH="${REPO_HOME}/runs/${EXP_NAME}/log/debug_log.$(date +%Y-%m-%d-%H-%M-%S).txt"
|
||||
# MAX_STEPS=1200 # TODO: change this to your own max steps
|
||||
|
||||
|
||||
# export WANDB_DISABLED=true
|
||||
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12349" \
|
||||
src/open_r1/grpo_jsonl.py \
|
||||
--use_vllm False \
|
||||
--output_dir ${REPO_HOME}/checkpoints/rl/${EXP_NAME} \
|
||||
--resume_from_checkpoint True \
|
||||
--model_name_or_path $model_path \
|
||||
--data_file_paths $data_paths \
|
||||
--image_folders $image_folders \
|
||||
--is_reward_customized_from_vlm_module $is_reward_customized_from_vlm_module \
|
||||
--task_type $TASK_TYPE \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--gradient_checkpointing true \
|
||||
--logging_steps 1 \
|
||||
--num_train_epochs 2 \
|
||||
--bf16 \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--run_name ${EXP_NAME} \
|
||||
--data_seed 42 \
|
||||
--save_steps 100 \
|
||||
--num_generations 8 \
|
||||
--max_completion_length 2048 \
|
||||
--reward_funcs accuracy format \
|
||||
--beta 0.04 \
|
||||
--report_to wandb \
|
||||
--dataset-name this_is_not_used \
|
||||
--deepspeed ${REPO_HOME}/src/open-r1-multimodal/local_scripts/zero2.json \
|
||||
--learning_rate 1e-5 \
|
||||
--use_peft true \
|
||||
--lora_r 64 \
|
||||
--lora_alpha 128 \
|
||||
--lora_dropout 0.05 \
|
||||
--lora_task_type CAUSAL_LM \
|
||||
--freeze_vision_modules true
|
||||
|
||||
echo "Training completed for ${EXP_NAME}"
|
||||
60
post-training/VLM-R1/run_scripts/run_grpo_rec_more_params.sh
Normal file
@@ -0,0 +1,60 @@
|
||||
PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
|
||||
export REPO_HOME="${PROJECT_ROOT}"
|
||||
echo "REPO_HOME: $REPO_HOME"
|
||||
# on remote
|
||||
data_paths="/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcoco_train.jsonl:/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcocop_train.jsonl:/training/shz/dataset/vlm-r1/rec_jsonsl_train/refcocog_train.jsonl"
|
||||
image_folders="/training/shz/dataset/coco:/training/shz/dataset/coco:/training/shz/dataset/coco"
|
||||
model_path="/training/models/Qwen2.5-VL-3B-Instruct"
|
||||
is_reward_customized_from_vlm_module=True
|
||||
echo "data_paths: $data_paths"
|
||||
echo "image_folders: $image_folders"
|
||||
|
||||
export EXP_NAME="Qwen2.5-VL-3B-Instruct-rec-more-params" # TODO: change this to your own experiment name
|
||||
TASK_TYPE="rec"
|
||||
cd ${REPO_HOME}/src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
|
||||
# create the run directory and log file
|
||||
mkdir -p ${REPO_HOME}/runs/${EXP_NAME}/log
|
||||
export LOG_PATH="${REPO_HOME}/runs/${EXP_NAME}/log/debug_log.$(date +%Y-%m-%d-%H-%M-%S).txt"
|
||||
# MAX_STEPS=1200 # TODO: change this to your own max steps
|
||||
|
||||
|
||||
# export WANDB_DISABLED=true
|
||||
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12349" \
|
||||
src/open_r1/grpo_jsonl.py \
|
||||
--use_vllm False \
|
||||
--output_dir ${REPO_HOME}/checkpoints/rl/${EXP_NAME} \
|
||||
--resume_from_checkpoint True \
|
||||
--model_name_or_path $model_path \
|
||||
--data_file_paths $data_paths \
|
||||
--image_folders $image_folders \
|
||||
--is_reward_customized_from_vlm_module $is_reward_customized_from_vlm_module \
|
||||
--task_type $TASK_TYPE \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--gradient_checkpointing true \
|
||||
--logging_steps 1 \
|
||||
--num_train_epochs 2 \
|
||||
--bf16 \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--run_name ${EXP_NAME} \
|
||||
--data_seed 42 \
|
||||
--save_steps 100 \
|
||||
--num_generations 8 \
|
||||
--max_completion_length 2048 \
|
||||
--reward_funcs accuracy format \
|
||||
--beta 0.04 \
|
||||
--epsilon_high 0.28 \
|
||||
--report_to wandb \
|
||||
--dataset-name this_is_not_used \
|
||||
--deepspeed ${REPO_HOME}/src/open-r1-multimodal/local_scripts/zero3.json \
|
||||
|
||||
# epsilon_high is the additional parameter compared to the general grpo training script
|
||||
|
||||
echo "Training completed for ${EXP_NAME}"
|
||||
21
post-training/VLM-R1/setup.sh
Executable file
@@ -0,0 +1,21 @@
|
||||
# conda create -n vlm-r1 python=3.11
|
||||
# conda activate vlm-r1
|
||||
|
||||
# Install the packages in open-r1-multimodal .
|
||||
cd src/open-r1-multimodal # We edit the grpo.py and grpo_trainer.py in open-r1 repo.
|
||||
pip install -e ".[dev]"
|
||||
|
||||
# Addtional modules
|
||||
pip install wandb==0.18.3
|
||||
pip install tensorboardx
|
||||
pip install qwen_vl_utils torchvision
|
||||
pip install flash-attn --no-build-isolation
|
||||
pip install babel
|
||||
pip install python-Levenshtein
|
||||
pip install matplotlib
|
||||
pip install pycocotools
|
||||
pip install openai
|
||||
pip install httpx[socks]
|
||||
pip install lap scikit-image open-clip-torch playwright
|
||||
playwright install
|
||||
playwright install-deps
|
||||
178
post-training/VLM-R1/src/eval/test_od_r1.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import re
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import random
|
||||
|
||||
from tqdm import tqdm
|
||||
from pprint import pprint
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
|
||||
|
||||
def extract_bbox_answer(content):
|
||||
pattern = r'```json(.*?)```'
|
||||
json_match = re.search(pattern, content, re.DOTALL)
|
||||
bbox_json = json_match.group(1).strip() if json_match else None
|
||||
|
||||
if bbox_json:
|
||||
try:
|
||||
bbox = json.loads(bbox_json)[0]['bbox_2d']
|
||||
return bbox, False
|
||||
except:
|
||||
return [0, 0, 0, 0], False
|
||||
else:
|
||||
return [0, 0, 0, 0], False
|
||||
|
||||
|
||||
def iou(box1, box2):
|
||||
inter_x1 = max(box1[0], box2[0])
|
||||
inter_y1 = max(box1[1], box2[1])
|
||||
inter_x2 = min(box1[2] - 1, box2[2] - 1)
|
||||
inter_y2 = min(box1[3] - 1, box2[3] - 1)
|
||||
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
||||
inter = (inter_x2 - inter_x1 + 1) * (inter_y2 - inter_y1 + 1)
|
||||
else:
|
||||
inter = 0
|
||||
union = (box1[2] - box1[0]) * (box1[3] - box1[1]) + (box2[2] - box2[0]) * (box2[3] - box2[1]) - inter
|
||||
return float(inter) / union
|
||||
|
||||
|
||||
def load_model(model_path, device_map):
|
||||
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map=device_map,
|
||||
)
|
||||
|
||||
# default processer
|
||||
processor = AutoProcessor.from_pretrained(model_path)
|
||||
|
||||
return model, processor
|
||||
|
||||
|
||||
def eval_od_r1(
|
||||
model_path, test_datasets, data_root, image_root, question_template, output_dir, batch_size=32, sample_num=500, seed=42, device_map="cuda:0"
|
||||
):
|
||||
random.seed(seed)
|
||||
model, processor = load_model(model_path, device_map)
|
||||
|
||||
for ds in test_datasets:
|
||||
print(f"Processing {ds}...")
|
||||
|
||||
ds_path = os.path.join(data_root, f"{ds}.json")
|
||||
data = json.load(open(ds_path, "r"))
|
||||
random.shuffle(data)
|
||||
data = data[:sample_num]
|
||||
messages = []
|
||||
|
||||
for x in data:
|
||||
image_path = os.path.join(image_root, x['image'])
|
||||
messages.append(
|
||||
[
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
[
|
||||
{
|
||||
"type": "image",
|
||||
"image": f"file://{image_path}"
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": question_template.format(Question=x['normal_caption'])
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
all_outputs = [] # List to store all answers
|
||||
|
||||
# Process data
|
||||
for i in tqdm(range(0, len(messages), batch_size)):
|
||||
batch_messages = messages[i:i + batch_size]
|
||||
|
||||
# Preparation for inference
|
||||
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
||||
|
||||
image_inputs, video_inputs = process_vision_info(batch_messages)
|
||||
inputs = processor(
|
||||
text=text,
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(device_map)
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
||||
|
||||
generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
|
||||
batch_output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
all_outputs.extend(batch_output_text)
|
||||
|
||||
final_output = []
|
||||
correct_number = 0
|
||||
|
||||
for input_example, model_output in zip(data, all_outputs):
|
||||
original_output = model_output
|
||||
ground_truth = input_example['solution']
|
||||
ground_truth_normalized = input_example['normalized_solution']
|
||||
model_answer, normalized = extract_bbox_answer(original_output)
|
||||
|
||||
# Count correct answers
|
||||
correct = 0
|
||||
if model_answer is not None:
|
||||
iou_value = iou(model_answer, ground_truth_normalized if normalized else ground_truth)
|
||||
if iou_value > 0.5:
|
||||
correct = 1
|
||||
correct_number += correct
|
||||
|
||||
# Create a result dictionary for this example
|
||||
result = {
|
||||
"question": question_template.format(Question=input_example['normal_caption']),
|
||||
"ground_truth": ground_truth if not normalized else ground_truth_normalized,
|
||||
"model_output": original_output,
|
||||
"extracted_answer": model_answer,
|
||||
"correct": correct,
|
||||
"iou": iou_value
|
||||
}
|
||||
final_output.append(result)
|
||||
|
||||
# Calculate and print accuracy
|
||||
accuracy = correct_number / len(data) * 100
|
||||
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
||||
|
||||
# Save results to a JSON file
|
||||
result_path = os.path.join(output_dir, f"{os.path.basename(model_path)}", f"{ds}_od_r1.json")
|
||||
os.makedirs(os.path.dirname(result_path), exist_ok=True)
|
||||
with open(result_path, "w") as f:
|
||||
json.dump({"accuracy": accuracy, "results": final_output}, f, indent=2)
|
||||
|
||||
print(f"Results saved to {result_path}")
|
||||
print('-' * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_path = '' # Add the path to the model
|
||||
data_root = '' # Add the data root
|
||||
test_datasets = ['refcoco_val', 'refcocop_val', 'refcocog_val'] # modify the datasets
|
||||
image_root = '' # Add the image root
|
||||
output_dir = 'logs' # Add the output directory, default is logs
|
||||
device_map = 'cuda:0' # select the device, default is cuda:0
|
||||
|
||||
question_template = '{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format.' # modify the question template which must contain {Question}, {Question} will be replaced by the caption
|
||||
|
||||
eval_od_r1(
|
||||
model_path=model_path,
|
||||
data_root=data_root,
|
||||
test_datasets=test_datasets,
|
||||
image_root=image_root,
|
||||
question_template=question_template,
|
||||
output_dir=output_dir,
|
||||
device_map=device_map
|
||||
)
|
||||
225
post-training/VLM-R1/src/eval/test_rec_baseline.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import os
|
||||
from pprint import pprint
|
||||
import random
|
||||
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import argparse
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
||||
|
||||
def setup_distributed():
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
|
||||
print(f"Process {rank}/{world_size} initialized on cuda:{local_rank}")
|
||||
return local_rank, world_size, rank
|
||||
|
||||
local_rank, world_size, rank = setup_distributed()
|
||||
device = f"cuda:{local_rank}"
|
||||
|
||||
steps = 100
|
||||
MODEL_PATH=f"/data10/shz/project/LLaMA-Factory/saves/qwen2_5_vl-3b/full/sft/checkpoint-{steps}"
|
||||
OUTPUT_PATH="./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_sft_{STEPS}.json"
|
||||
|
||||
# MODEL_PATH = "/data10/shz/ckpt/vlm-r1-related/Qwen2.5-VL-3B-Instruct"
|
||||
# OUTPUT_PATH = "./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_baseline_{STEPS}.json"
|
||||
|
||||
BSZ=4
|
||||
DATA_ROOT = "/data10/shz/dataset/rec/rec_jsons_processed"
|
||||
|
||||
TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
|
||||
IMAGE_ROOT = "/data10/shz/dataset/coco"
|
||||
|
||||
# TEST_DATASETS = ['lisa_test']
|
||||
# IMAGE_ROOT = "/data10/shz/dataset/lisa"
|
||||
|
||||
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map={"": local_rank},
|
||||
)
|
||||
|
||||
# default processer
|
||||
processor = AutoProcessor.from_pretrained(MODEL_PATH)
|
||||
|
||||
def extract_bbox_answer(content):
|
||||
bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
|
||||
# bbox_pattern = r'\[(-?\d*\.?\d+),\s*(-?\d*\.?\d+),\s*(-?\d*\.?\d+),\s*(-?\d*\.?\d+)\]'
|
||||
bbox_match = re.search(bbox_pattern, content)
|
||||
|
||||
if bbox_match:
|
||||
bbox = [float(bbox_match.group(1)), float(bbox_match.group(2)), float(bbox_match.group(3)), float(bbox_match.group(4))]
|
||||
return bbox
|
||||
return [0, 0, 0, 0]
|
||||
|
||||
def iou(box1, box2):
|
||||
inter_x1 = max(box1[0], box2[0])
|
||||
inter_y1 = max(box1[1], box2[1])
|
||||
inter_x2 = min(box1[2]-1, box2[2]-1)
|
||||
inter_y2 = min(box1[3]-1, box2[3]-1)
|
||||
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
||||
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
|
||||
else:
|
||||
inter = 0
|
||||
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
|
||||
return float(inter)/union
|
||||
|
||||
num_samples = 2000
|
||||
for ds in TEST_DATASETS:
|
||||
if rank == 0:
|
||||
print(f"Processing {ds}...")
|
||||
ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
|
||||
data = json.load(open(ds_path, "r"))
|
||||
random.seed(42)
|
||||
random.shuffle(data)
|
||||
data = data[:num_samples]
|
||||
# QUESTION_TEMPLATE = "{Question}" if steps > 0 else "{Question} Please provide the bounding box coordinate in JSON format."
|
||||
QUESTION_TEMPLATE = "{Question} Please provide the bounding box coordinate in JSON format."
|
||||
|
||||
# Split data for distributed evaluation
|
||||
per_rank_data = len(data) // world_size
|
||||
start_idx = rank * per_rank_data
|
||||
end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
|
||||
rank_data = data[start_idx:end_idx]
|
||||
|
||||
messages = []
|
||||
|
||||
for x in rank_data:
|
||||
image_path = os.path.join(IMAGE_ROOT, x['image'])
|
||||
message = [
|
||||
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": f"file://{image_path}"
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": QUESTION_TEMPLATE.format(Question=x['problem'])
|
||||
}
|
||||
]
|
||||
}]
|
||||
messages.append(message)
|
||||
|
||||
rank_outputs = [] # List to store answers for this rank
|
||||
all_outputs = [] # List to store all answers
|
||||
|
||||
# Process data
|
||||
for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
|
||||
batch_messages = messages[i:i + BSZ]
|
||||
|
||||
# Preparation for inference
|
||||
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
||||
|
||||
image_inputs, video_inputs = process_vision_info(batch_messages)
|
||||
inputs = processor(
|
||||
text=text,
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(device)
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
||||
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
batch_output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
rank_outputs.extend(batch_output_text)
|
||||
|
||||
print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
|
||||
|
||||
# Gather all outputs from all ranks
|
||||
all_outputs = [None] * len(data)
|
||||
rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
|
||||
|
||||
gathered_results = [None] * world_size
|
||||
dist.all_gather_object(gathered_results, rank_results)
|
||||
|
||||
assert gathered_results[-1][-1][0] == len(data) - 1
|
||||
|
||||
# The main process will collect all results
|
||||
if rank == 0:
|
||||
for results in gathered_results:
|
||||
for idx, output in results:
|
||||
assert idx < len(all_outputs)
|
||||
all_outputs[idx] = output
|
||||
assert all_outputs[-1] is not None
|
||||
|
||||
final_output = []
|
||||
correct_number = 0
|
||||
|
||||
for input_example, model_output in zip(data, all_outputs):
|
||||
original_output = model_output
|
||||
ground_truth = input_example['solution']
|
||||
model_answer = extract_bbox_answer(original_output)
|
||||
|
||||
# Count correct answers
|
||||
correct = 0
|
||||
if model_answer is not None:
|
||||
if iou(model_answer, ground_truth) > 0.5:
|
||||
correct = 1
|
||||
correct_number += correct
|
||||
|
||||
# Create a result dictionary for this example
|
||||
result = {
|
||||
'image': input_example['image'],
|
||||
'question': input_example['problem'],
|
||||
'ground_truth': ground_truth,
|
||||
'model_output': original_output,
|
||||
'extracted_answer': model_answer,
|
||||
'correct': correct
|
||||
}
|
||||
final_output.append(result)
|
||||
|
||||
# Calculate and print accuracy
|
||||
accuracy = correct_number / len(data) * 100
|
||||
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
||||
|
||||
# Save results to a JSON file
|
||||
output_path = OUTPUT_PATH.format(DATASET=ds, STEPS=steps)
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
with open(output_path, "w") as f:
|
||||
json.dump({
|
||||
'accuracy': accuracy,
|
||||
'results': final_output
|
||||
}, f, indent=2)
|
||||
|
||||
print(f"Results saved to {output_path}")
|
||||
print("-"*100)
|
||||
|
||||
# Synchronize all processes
|
||||
dist.barrier()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
233
post-training/VLM-R1/src/eval/test_rec_r1.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import os
|
||||
from pprint import pprint
|
||||
import random
|
||||
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import argparse
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
||||
|
||||
def setup_distributed():
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
|
||||
return local_rank, world_size, rank
|
||||
|
||||
local_rank, world_size, rank = setup_distributed()
|
||||
device = f"cuda:{local_rank}"
|
||||
print(f"Process {rank} using {device}")
|
||||
|
||||
main_rank = 0
|
||||
steps = 100
|
||||
if rank == main_rank:
|
||||
print("Steps: ", steps)
|
||||
|
||||
RUN_NAME = "Qwen2.5-VL-3B-Instruct-rec"
|
||||
|
||||
MODEL_PATH=f"/training/shz/project/vlm-r1/VLM-R1/checkpoints/rl/{RUN_NAME}/checkpoint-{steps}"
|
||||
OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
|
||||
|
||||
BSZ=2
|
||||
DATA_ROOT = "/training/shz/dataset/vlm-r1/rec_jsons_processed"
|
||||
|
||||
# TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
|
||||
# IMAGE_ROOT = "/training/shz/dataset/coco"
|
||||
|
||||
|
||||
TEST_DATASETS = ['lisa_test']
|
||||
IMAGE_ROOT = "/training/shz/dataset/lisa"
|
||||
|
||||
|
||||
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map={"": local_rank},
|
||||
)
|
||||
|
||||
# default processer
|
||||
processor = AutoProcessor.from_pretrained(MODEL_PATH)
|
||||
|
||||
def extract_bbox_answer(content):
|
||||
# Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
|
||||
answer_tag_pattern = r'<answer>(.*?)</answer>'
|
||||
bbox_pattern = r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}'
|
||||
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
||||
if content_answer_match:
|
||||
content_answer = content_answer_match.group(1).strip()
|
||||
bbox_match = re.search(bbox_pattern, content_answer, re.DOTALL)
|
||||
if bbox_match:
|
||||
bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
|
||||
return bbox
|
||||
return [0, 0, 0, 0]
|
||||
|
||||
def iou(box1, box2):
|
||||
inter_x1 = max(box1[0], box2[0])
|
||||
inter_y1 = max(box1[1], box2[1])
|
||||
inter_x2 = min(box1[2]-1, box2[2]-1)
|
||||
inter_y2 = min(box1[3]-1, box2[3]-1)
|
||||
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
||||
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
|
||||
else:
|
||||
inter = 0
|
||||
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
|
||||
return float(inter)/union
|
||||
|
||||
num_samples = 2000
|
||||
for ds in TEST_DATASETS:
|
||||
if rank == 0:
|
||||
print(f"Processing {ds}...")
|
||||
ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
|
||||
data = json.load(open(ds_path, "r"))
|
||||
random.seed(42)
|
||||
random.shuffle(data)
|
||||
data = data[:num_samples]
|
||||
|
||||
QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
|
||||
|
||||
# Split data for distributed evaluation
|
||||
per_rank_data = len(data) // world_size
|
||||
start_idx = rank * per_rank_data
|
||||
end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
|
||||
rank_data = data[start_idx:end_idx]
|
||||
|
||||
messages = []
|
||||
|
||||
for x in rank_data:
|
||||
image_path = os.path.join(IMAGE_ROOT, x['image'])
|
||||
message = [
|
||||
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": f"file://{image_path}"
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": QUESTION_TEMPLATE.format(Question=x['problem'])
|
||||
}
|
||||
]
|
||||
}]
|
||||
messages.append(message)
|
||||
|
||||
rank_outputs = [] # List to store answers for this rank
|
||||
all_outputs = [] # List to store all answers
|
||||
|
||||
# Process data
|
||||
for i in tqdm(range(0, len(messages), BSZ), disable=rank != main_rank):
|
||||
batch_messages = messages[i:i + BSZ]
|
||||
|
||||
# Preparation for inference
|
||||
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
||||
|
||||
image_inputs, video_inputs = process_vision_info(batch_messages)
|
||||
inputs = processor(
|
||||
text=text,
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(device)
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
||||
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
batch_output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
rank_outputs.extend(batch_output_text)
|
||||
|
||||
print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
|
||||
|
||||
# Gather all outputs from all ranks
|
||||
all_outputs = [None] * len(data)
|
||||
rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
|
||||
|
||||
gathered_results = [None] * world_size
|
||||
dist.all_gather_object(gathered_results, rank_results)
|
||||
|
||||
assert gathered_results[-1][-1][0] == len(data) - 1
|
||||
|
||||
# The main process will collect all results
|
||||
if rank == main_rank:
|
||||
for results in gathered_results:
|
||||
for idx, output in results:
|
||||
assert idx < len(all_outputs)
|
||||
all_outputs[idx] = output
|
||||
assert all_outputs[-1] is not None
|
||||
|
||||
final_output = []
|
||||
correct_number = 0
|
||||
|
||||
for input_example, model_output in zip(data, all_outputs):
|
||||
original_output = model_output
|
||||
ground_truth = input_example['solution']
|
||||
model_answer = extract_bbox_answer(original_output)
|
||||
|
||||
# Count correct answers
|
||||
correct = 0
|
||||
if model_answer is not None:
|
||||
if iou(model_answer, ground_truth) > 0.5:
|
||||
correct = 1
|
||||
correct_number += correct
|
||||
|
||||
# Create a result dictionary for this example
|
||||
result = {
|
||||
'image': input_example['image'],
|
||||
'question': input_example['problem'],
|
||||
'ground_truth': ground_truth,
|
||||
'model_output': original_output,
|
||||
'extracted_answer': model_answer,
|
||||
'correct': correct
|
||||
}
|
||||
final_output.append(result)
|
||||
|
||||
# Calculate and print accuracy
|
||||
accuracy = correct_number / len(data) * 100
|
||||
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
||||
|
||||
# Save results to a JSON file
|
||||
output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
with open(output_path, "w") as f:
|
||||
json.dump({
|
||||
'accuracy': accuracy,
|
||||
'results': final_output
|
||||
}, f, indent=2)
|
||||
|
||||
print(f"Results saved to {output_path}")
|
||||
print("-"*100)
|
||||
|
||||
# Synchronize all processes
|
||||
dist.barrier()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
231
post-training/VLM-R1/src/eval/test_rec_r1_internvl.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import torch
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import os
|
||||
from pprint import pprint
|
||||
import random
|
||||
from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
|
||||
from open_r1.vlm_modules.internvl_module import InvernVLModule
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
||||
|
||||
def setup_distributed():
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
|
||||
return local_rank, world_size, rank
|
||||
|
||||
local_rank, world_size, rank = setup_distributed()
|
||||
device = f"cuda:{local_rank}"
|
||||
print(f"Process {rank} using {device}")
|
||||
|
||||
main_rank = 0
|
||||
steps = 300
|
||||
if rank == main_rank:
|
||||
print("Steps: ", steps)
|
||||
|
||||
RUN_NAME = "InternVL2_5-4B_MPO-rec"
|
||||
|
||||
MODEL_PATH=f"/training/shz/project/vlm-r1/VLM-R1/checkpoints/rl/{RUN_NAME}/checkpoint-{steps}"
|
||||
OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
|
||||
|
||||
BSZ=4
|
||||
DATA_ROOT = "/training/shz/dataset/vlm-r1/rec_jsons_internvl"
|
||||
|
||||
# TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
|
||||
# IMAGE_ROOT = "/training/shz/dataset/coco"
|
||||
|
||||
TEST_DATASETS = ['lisa_test']
|
||||
IMAGE_ROOT = "/training/shz/dataset/lisa"
|
||||
|
||||
random.seed(42)
|
||||
|
||||
vlm_module = InvernVLModule()
|
||||
|
||||
model = vlm_module.get_model_class(MODEL_PATH, {}).from_pretrained(
|
||||
MODEL_PATH,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map={"": local_rank},
|
||||
trust_remote_code=True,
|
||||
use_flash_attn=True,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
||||
vlm_module.post_model_init(model, tokenizer)
|
||||
|
||||
|
||||
def extract_bbox_answer(content):
|
||||
# Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
|
||||
answer_tag_pattern = r'<answer>(.*?)</answer>'
|
||||
bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
|
||||
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
||||
if content_answer_match:
|
||||
content_answer = content_answer_match.group(1).strip()
|
||||
bbox_match = re.search(bbox_pattern, content_answer, re.DOTALL)
|
||||
if bbox_match:
|
||||
bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
|
||||
return bbox
|
||||
return [0, 0, 0, 0]
|
||||
|
||||
def iou(box1, box2):
|
||||
inter_x1 = max(box1[0], box2[0])
|
||||
inter_y1 = max(box1[1], box2[1])
|
||||
inter_x2 = min(box1[2]-1, box2[2]-1)
|
||||
inter_y2 = min(box1[3]-1, box2[3]-1)
|
||||
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
||||
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
|
||||
else:
|
||||
inter = 0
|
||||
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
|
||||
return float(inter)/union
|
||||
|
||||
from PIL import Image
|
||||
def process_vision_info(batch_messages):
|
||||
images = []
|
||||
for msg in batch_messages:
|
||||
image_path = msg[0]['content'][0]['image'].replace("file://", "")
|
||||
image = Image.open(image_path)
|
||||
images.append(image)
|
||||
return images
|
||||
|
||||
|
||||
sample_num = 2000
|
||||
tokenizer.max_anyres_num = 12
|
||||
for ds in TEST_DATASETS:
|
||||
if rank == main_rank:
|
||||
print(f"Processing {ds}...")
|
||||
ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
|
||||
data = json.load(open(ds_path, "r"))
|
||||
random.seed(42)
|
||||
random.shuffle(data)
|
||||
data = data[:sample_num]
|
||||
QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
|
||||
|
||||
# Split data for distributed evaluation
|
||||
per_rank_data = len(data) // world_size
|
||||
start_idx = rank * per_rank_data
|
||||
end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
|
||||
rank_data = data[start_idx:end_idx]
|
||||
|
||||
messages = []
|
||||
for x in rank_data:
|
||||
image_path = os.path.join(IMAGE_ROOT, x['image'])
|
||||
message = [
|
||||
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": f"file://{image_path}"
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": QUESTION_TEMPLATE.format(Question=x['problem'])
|
||||
}
|
||||
]
|
||||
}]
|
||||
messages.append(message)
|
||||
|
||||
rank_outputs = [] # List to store answers for this rank
|
||||
all_outputs = [] # List to store all answers
|
||||
|
||||
# Process data
|
||||
for i in tqdm(range(0, len(messages), BSZ), disable=rank != main_rank):
|
||||
batch_messages = messages[i:i + BSZ]
|
||||
prompts = vlm_module.prepare_prompt(None, [{"prompt": msg} for msg in batch_messages])
|
||||
|
||||
images = process_vision_info(batch_messages)
|
||||
|
||||
model_inputs = vlm_module.prepare_model_inputs(tokenizer, prompts, images)
|
||||
model_inputs['pixel_values'] = model_inputs['pixel_values'].to(torch.bfloat16)
|
||||
model_inputs = model_inputs.to(device)
|
||||
|
||||
outputs = model.generate(**{k:v for k,v in model_inputs.items() if k not in vlm_module.get_non_generate_params()}, max_new_tokens=1024, do_sample=False, pad_token_id=tokenizer.eos_token_id)
|
||||
batch_output_text = tokenizer.batch_decode(
|
||||
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
rank_outputs.extend(batch_output_text)
|
||||
|
||||
print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
|
||||
|
||||
# Gather all outputs from all ranks
|
||||
all_outputs = [None] * len(data)
|
||||
rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
|
||||
|
||||
gathered_results = [None] * world_size
|
||||
dist.all_gather_object(gathered_results, rank_results)
|
||||
|
||||
assert gathered_results[-1][-1][0] == len(data) - 1
|
||||
|
||||
# The main process will collect all results
|
||||
if rank == main_rank:
|
||||
for results in gathered_results:
|
||||
for idx, output in results:
|
||||
assert idx < len(all_outputs)
|
||||
all_outputs[idx] = output
|
||||
assert all_outputs[-1] is not None
|
||||
|
||||
final_output = []
|
||||
correct_number = 0
|
||||
|
||||
for input_example, model_output in zip(data, all_outputs):
|
||||
original_output = model_output
|
||||
ground_truth = input_example['solution']
|
||||
model_answer = extract_bbox_answer(original_output)
|
||||
|
||||
# Count correct answers
|
||||
correct = 0
|
||||
if model_answer is not None and iou(model_answer, ground_truth) > 0.5:
|
||||
correct = 1
|
||||
correct_number += correct
|
||||
|
||||
# Create a result dictionary for this example
|
||||
result = {
|
||||
'image': input_example['image'],
|
||||
'question': input_example['problem'],
|
||||
'ground_truth': ground_truth,
|
||||
'model_output': original_output,
|
||||
'extracted_answer': model_answer,
|
||||
'correct': correct
|
||||
}
|
||||
final_output.append(result)
|
||||
|
||||
# Calculate and print accuracy
|
||||
accuracy = correct_number / len(data) * 100
|
||||
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
||||
|
||||
# Save results to a JSON file
|
||||
output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
with open(output_path, "w") as f:
|
||||
json.dump({
|
||||
'accuracy': accuracy,
|
||||
'results': final_output
|
||||
}, f, indent=4)
|
||||
|
||||
print(f"Results saved to {output_path}")
|
||||
print("-"*100)
|
||||
|
||||
# Synchronize all processes
|
||||
dist.barrier()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
178
post-training/VLM-R1/src/open-r1-multimodal/.gitignore
vendored
Normal file
@@ -0,0 +1,178 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Temp folders
|
||||
data/
|
||||
wandb/
|
||||
scripts/
|
||||
checkpoints/
|
||||
.vscode/
|
||||
201
post-training/VLM-R1/src/open-r1-multimodal/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
20
post-training/VLM-R1/src/open-r1-multimodal/Makefile
Normal file
@@ -0,0 +1,20 @@
|
||||
.PHONY: style quality
|
||||
|
||||
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
|
||||
export PYTHONPATH = src
|
||||
|
||||
check_dirs := src
|
||||
|
||||
style:
|
||||
black --line-length 119 --target-version py310 $(check_dirs) setup.py
|
||||
isort $(check_dirs) setup.py
|
||||
|
||||
quality:
|
||||
black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
|
||||
isort --check-only $(check_dirs) setup.py
|
||||
flake8 --max-line-length 119 $(check_dirs) setup.py
|
||||
|
||||
|
||||
# Evaluation
|
||||
|
||||
evaluate:
|
||||
16
post-training/VLM-R1/src/open-r1-multimodal/configs/ddp.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,42 @@
|
||||
# Model arguments
|
||||
model_name_or_path: /data/shz/ckpt/Qwen2.5-VL-3B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: /data/shz/project/vlm-r1/VLM-R1/src/open-r1-multimodal/data_script/rec.yaml
|
||||
image_root: /data/shz/dataset/coco
|
||||
dataset_configs:
|
||||
- all
|
||||
preprocessing_num_workers: 8
|
||||
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: Qwen2.5-VL-3B-Instruct
|
||||
hub_strategy: every_save
|
||||
learning_rate: 2.0e-05
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
packing: true
|
||||
max_seq_length: 4096
|
||||
max_steps: -1
|
||||
num_train_epochs: 3
|
||||
output_dir: /data/shz/project/vlm-r1/VLM-R1/output/Qwen2.5-VL-3B-Instruct-SFT
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 1
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: false
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: "no"
|
||||
seed: 42
|
||||
data_seed: 42
|
||||
warmup_ratio: 0.1
|
||||
@@ -0,0 +1,21 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 3
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,4 @@
|
||||
datasets:
|
||||
- json_path: /data/shz/project/vlm-r1/VLM-R1/src/data/rec_jsons_processed/refcoco_train.json
|
||||
- json_path: /data/shz/project/vlm-r1/VLM-R1/src/data/rec_jsons_processed/refcocop_train.json
|
||||
- json_path: /data/shz/project/vlm-r1/VLM-R1/src/data/rec_jsons_processed/refcocog_train.json
|
||||
@@ -0,0 +1,4 @@
|
||||
datasets:
|
||||
- json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcoco_train.json
|
||||
- json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocop_train.json
|
||||
- json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocog_train.json
|
||||
@@ -0,0 +1,153 @@
|
||||
import argparse
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
|
||||
from tqdm import tqdm
|
||||
|
||||
import bytedtos
|
||||
import seaborn as sns
|
||||
import yaml
|
||||
from openai import AzureOpenAI
|
||||
from PIL import Image
|
||||
from pillow_avif import AvifImagePlugin
|
||||
|
||||
|
||||
PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
|
||||
|
||||
Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".
|
||||
|
||||
Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
|
||||
|
||||
Input Format:
|
||||
Original Question: {original_question}
|
||||
Original Answer: {original_answer}
|
||||
|
||||
Output Format:
|
||||
Question: [rewrite the question if necessary]
|
||||
Answer: [answer with reasoning steps, including calculations where applicable]
|
||||
<think>step-by-step reasoning process</think>
|
||||
<answer>easy to verify answer</answer>
|
||||
"""
|
||||
|
||||
|
||||
def get_image_data_url(image_input):
|
||||
if isinstance(image_input, str) and image_input.startswith("data:"):
|
||||
return image_input
|
||||
|
||||
if isinstance(image_input, str) and image_input.startswith("http"):
|
||||
image_input = load_image(image_input)
|
||||
|
||||
if isinstance(image_input, str):
|
||||
image_input = Image.open(image_input)
|
||||
|
||||
if not isinstance(image_input, Image.Image):
|
||||
raise ValueError("Unsupported image input type")
|
||||
|
||||
if image_input.mode != "RGB":
|
||||
image_input = image_input.convert("RGB")
|
||||
|
||||
buffer = BytesIO()
|
||||
image_input.save(buffer, format="JPEG")
|
||||
img_bytes = buffer.getvalue()
|
||||
base64_data = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return f"data:image/jpeg;base64,{base64_data}"
|
||||
|
||||
|
||||
def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
data_url_list = [get_image_data_url(image)]
|
||||
client = AzureOpenAI(
|
||||
azure_endpoint="YOUR_AZURE_ENDPOINT",
|
||||
api_version="2023-07-01-preview",
|
||||
api_key="YOUR_API_KEY",
|
||||
)
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an expert to analyze the image and provide useful information for users.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for data_url in data_url_list:
|
||||
messages[1]["content"].insert(
|
||||
0, {"type": "image_url", "image_url": {"url": data_url}}
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4o-2024-08-06",
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
max_tokens=8192,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
except Exception as e:
|
||||
if attempt == max_retries - 1:
|
||||
raise Exception(
|
||||
f"Failed after {max_retries} attempts. Last error: {str(e)}"
|
||||
)
|
||||
delay = initial_delay * (2**attempt) + random.uniform(
|
||||
0, 0.1 * initial_delay * (2**attempt)
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
def process_single_item(example):
|
||||
try:
|
||||
image_path = example["image_path"]
|
||||
formatted_prompt = PROMPT_FORMAT.format(
|
||||
original_question=example["question"], original_answer=example["answer"]
|
||||
)
|
||||
|
||||
response = gpt4o_query(image_path, formatted_prompt)
|
||||
example["gpt4o_response"] = response
|
||||
return example
|
||||
except Exception as e:
|
||||
print(f"Error processing item: {str(e)}")
|
||||
example["gpt4o_response"] = None
|
||||
return example
|
||||
|
||||
|
||||
def main():
|
||||
dataset_path = "path/to/your/dataset"
|
||||
full_dataset = load_from_disk(dataset_path)
|
||||
|
||||
processed_dataset = full_dataset.map(
|
||||
function=partial(process_single_item),
|
||||
num_proc=256,
|
||||
desc="Processing dataset with GPT-4o",
|
||||
keep_in_memory=True,
|
||||
)
|
||||
|
||||
output_path = f"{dataset_path}_processed"
|
||||
processed_dataset.save_to_disk(output_path)
|
||||
print(f"Processed dataset saved to: {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,61 @@
|
||||
export HF_HOME="<CACHE_DIR>"
|
||||
export HF_TOKEN="<HF_TOKEN>"
|
||||
export HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
export API_TYPE="<API_TYPE>"
|
||||
export AZURE_ENDPOINT="<AZURE_ENDPOINT>"
|
||||
export AZURE_API_KEY="<API_KEY>"
|
||||
export API_VERSION="<API_VERSION>"
|
||||
export MODEL_VERSION="<MODEL_VERSION>"
|
||||
export NAVIT_ATTENTION_IMPLEMENTATION="eager"
|
||||
|
||||
# Prompt for installation with 3-second timeout
|
||||
read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
|
||||
if [ "$install_deps" = "YES" ]; then
|
||||
# Prepare the environment
|
||||
pip3 install --upgrade pip
|
||||
pip3 install -U setuptools
|
||||
|
||||
cd <PROJECT_ROOT>
|
||||
if [ ! -d "maas_engine" ]; then
|
||||
git clone <REPO_URL>
|
||||
else
|
||||
echo "maas_engine directory already exists, skipping clone"
|
||||
fi
|
||||
cd maas_engine
|
||||
git pull
|
||||
git checkout <BRANCH_NAME>
|
||||
pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
|
||||
|
||||
current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
|
||||
if [ "$current_version" != "4.46.2" ]; then
|
||||
echo "Installing transformers 4.46.2 (current version: $current_version)"
|
||||
pip3 install transformers==4.46.2
|
||||
else
|
||||
echo "transformers 4.46.2 is already installed"
|
||||
fi
|
||||
|
||||
cd <LMMS_EVAL_DIR>
|
||||
rm -rf <TARGET_DIR>
|
||||
pip3 install -e .
|
||||
pip3 install -U pydantic
|
||||
pip3 install Levenshtein
|
||||
pip3 install nltk
|
||||
python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
|
||||
fi
|
||||
|
||||
TASKS=mmmu_val,mathvista_testmini,mmmu_pro
|
||||
MODEL_BASENAME=qwen2_vl
|
||||
|
||||
model_checkpoint="<MODEL_CHECKPOINT_PATH>"
|
||||
echo "MODEL_BASENAME: ${MODEL_BASENAME}"
|
||||
cd <LMMS_EVAL_DIR>
|
||||
|
||||
python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
|
||||
--model qwen2_vl \
|
||||
--model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
|
||||
--tasks ${TASKS} \
|
||||
--batch_size 1 \
|
||||
--log_samples \
|
||||
--log_samples_suffix ${MODEL_BASENAME} \
|
||||
--output_path ./logs
|
||||
@@ -0,0 +1,166 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import pandas as pd
|
||||
import random
|
||||
from typing import List, Dict
|
||||
import numpy as np
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
import datasets
|
||||
|
||||
import io
|
||||
from datasets import load_dataset, load_from_disk, concatenate_datasets
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from pillow_avif import AvifImagePlugin
|
||||
from datasets import Dataset
|
||||
import json
|
||||
import yaml
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import random
|
||||
import base64
|
||||
from openai import AzureOpenAI
|
||||
import concurrent.futures
|
||||
from typing import List, Dict
|
||||
import argparse
|
||||
import time
|
||||
|
||||
|
||||
def extract_problem_solution(gpt4o_response):
|
||||
# Split the response into parts
|
||||
parts = gpt4o_response.split("<think>")
|
||||
|
||||
# Extract the problem (first part before any <think> tags)
|
||||
problem = parts[0].strip()
|
||||
# Remove "Question:" prefix if it exists
|
||||
problem = re.sub(r"^Question:\s*", "", problem)
|
||||
# Remove "Answer:" at the end of the problem
|
||||
problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
|
||||
|
||||
# Combine all the reasoning steps into a single <think> block
|
||||
think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
|
||||
solution = f"<think>{' '.join(think_parts)}</think>"
|
||||
|
||||
# Add the final answer if it exists, removing "Answer:" prefix
|
||||
if "<answer>" in gpt4o_response:
|
||||
final_answer = (
|
||||
gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
|
||||
)
|
||||
final_answer = re.sub(r"^Answer:\s*", "", final_answer)
|
||||
solution += f"\n\n<answer>{final_answer}</answer>"
|
||||
|
||||
return problem, solution
|
||||
|
||||
|
||||
def load_image_from_path(image_path):
|
||||
try:
|
||||
img = Image.open(image_path)
|
||||
return img
|
||||
except Exception as e:
|
||||
print(f"Error loading image {image_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def process_raw_data(raw_data):
|
||||
# Parse the raw data if it's a string
|
||||
if isinstance(raw_data, str):
|
||||
data = json.loads(raw_data)
|
||||
else:
|
||||
data = raw_data
|
||||
|
||||
# Extract problem and solution
|
||||
try:
|
||||
problem, solution = extract_problem_solution(data["gpt4o_response"])
|
||||
image = load_image_from_path(data["image_path"])
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"problem": problem,
|
||||
"solution": solution,
|
||||
"original_question": data["question"],
|
||||
"original_answer": data["answer"],
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error processing data {data}: {str(e)}")
|
||||
return {
|
||||
"image": None,
|
||||
"problem": None,
|
||||
"solution": None,
|
||||
"original_question": None,
|
||||
"original_answer": None,
|
||||
}
|
||||
|
||||
|
||||
raw_data_list = [
|
||||
"/path/to/reasoning_data_with_response_90k_verified",
|
||||
]
|
||||
|
||||
raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
|
||||
|
||||
processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
|
||||
|
||||
hf_dict = {
|
||||
"image": [],
|
||||
"problem": [],
|
||||
"solution": [],
|
||||
"original_question": [],
|
||||
"original_answer": [],
|
||||
}
|
||||
|
||||
for item in tqdm(processed_data):
|
||||
hf_dict["image"].append(item["image"])
|
||||
hf_dict["problem"].append(item["problem"])
|
||||
hf_dict["solution"].append(item["solution"])
|
||||
hf_dict["original_question"].append(item["original_question"])
|
||||
hf_dict["original_answer"].append(item["original_answer"])
|
||||
|
||||
|
||||
features = datasets.Features(
|
||||
{
|
||||
"image": datasets.Image(),
|
||||
"problem": datasets.Value("string"),
|
||||
"solution": datasets.Value("string"),
|
||||
"original_question": datasets.Value("string"),
|
||||
"original_answer": datasets.Value("string"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def has_empty_tags(text):
|
||||
# Pattern to match empty tags like <tag></tag>
|
||||
pattern = r"<[^>]+></[^>]+>"
|
||||
return bool(re.search(pattern, text))
|
||||
|
||||
|
||||
def has_answer_pattern(text):
|
||||
if "Answer:" in text:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement
|
||||
# Assuming the image is in a format that can be checked for dimensions
|
||||
# You might need to adjust this depending on how the image is stored in your dataset
|
||||
try:
|
||||
image = example["image"] # or however your image is accessed
|
||||
if isinstance(image, dict) and "height" in image and "width" in image:
|
||||
return image["height"] >= 28 and image["width"] >= 28
|
||||
# If image is a PIL Image or similar
|
||||
return image.height >= 28 and image.width >= 28
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
ds = datasets.Dataset.from_dict(hf_dict, features=features)
|
||||
ds = ds.filter(
|
||||
lambda x: not has_empty_tags(x["solution"])
|
||||
and not has_answer_pattern(x["problem"])
|
||||
and has_valid_image_size(x)
|
||||
and x["image"] is not None,
|
||||
num_proc=128,
|
||||
)
|
||||
# Push to Hugging Face Hub
|
||||
ds.push_to_hub("path/to/your/dataset")
|
||||
@@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
|
||||
export NCCL_BLOCKING_WAIT=0
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
export OMP_NUM_THREADS=8
|
||||
export NCCL_IB_DISABLE=0
|
||||
export NCCL_IB_GID_INDEX=3
|
||||
export NCCL_SOCKET_IFNAME=eth0
|
||||
export NCCL_DEBUG=INFO
|
||||
|
||||
# CONFIG Huggingface
|
||||
# export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_1>"
|
||||
export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_2>"
|
||||
export HF_HOME="$HOME/.cache/huggingface"
|
||||
export HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
export NCCL_DEBUG=INFO
|
||||
|
||||
GPUS="0,1,2,3,4,5,6,7"
|
||||
|
||||
# 取 worker0 第一个 port
|
||||
ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
|
||||
port=${ports[0]}
|
||||
port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
|
||||
|
||||
echo "total workers: ${ARNOLD_WORKER_NUM}"
|
||||
echo "cur worker id: ${ARNOLD_ID}"
|
||||
echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
|
||||
echo "master ip: ${METIS_WORKER_0_HOST}"
|
||||
echo "master port: ${port}"
|
||||
echo "master port in cmd: ${port_in_cmd}"
|
||||
|
||||
# export WANDB_BASE_URL=https://api.wandb.ai
|
||||
# export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
|
||||
# wandb login $WANDB_API_KEY
|
||||
|
||||
export WANDB_BASE_URL=https://api.wandb.ai
|
||||
export WANDB_PROJECT=vision-reasoning
|
||||
export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
|
||||
export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
|
||||
wandb login $WANDB_API_KEY
|
||||
|
||||
cd /home/tiger/multimodal-open-r1
|
||||
# pip3 install vllm==0.6.6.post1
|
||||
pip3 install -e ".[dev]"
|
||||
pip3 install wandb==0.18.3
|
||||
|
||||
torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
|
||||
--nnodes="${ARNOLD_WORKER_NUM}" \
|
||||
--node_rank="${ARNOLD_ID}" \
|
||||
--master_addr="${METIS_WORKER_0_HOST}" \
|
||||
--master_port="${port_in_cmd}" \
|
||||
src/open_r1/grpo.py \
|
||||
--deepspeed scripts/zero3.json \
|
||||
--output_dir Aria-GRPO-mini_cot_80k \
|
||||
--model_name_or_path rhymes-ai/Aria \
|
||||
--dataset_name luodian/mini_cot_80k \
|
||||
--max_prompt_length 8192 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 1 \
|
||||
--bf16 \
|
||||
--report_to wandb \
|
||||
--gradient_checkpointing true \
|
||||
--attn_implementation eager \
|
||||
--save_total_limit 8 \
|
||||
--num_train_epochs 1 \
|
||||
--run_name $WANDB_RUN_NAME
|
||||
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
|
||||
export NCCL_BLOCKING_WAIT=0
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
export OMP_NUM_THREADS=8
|
||||
export NCCL_IB_DISABLE=0
|
||||
export NCCL_IB_GID_INDEX=3
|
||||
export NCCL_SOCKET_IFNAME=eth0
|
||||
export NCCL_DEBUG=INFO
|
||||
|
||||
GPUS="0,1,2,3,4,5,6,7"
|
||||
|
||||
# 取 worker0 第一个 port
|
||||
ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
|
||||
port=${ports[0]}
|
||||
port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
|
||||
|
||||
echo "total workers: ${ARNOLD_WORKER_NUM}"
|
||||
echo "cur worker id: ${ARNOLD_ID}"
|
||||
echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
|
||||
echo "master ip: ${METIS_WORKER_0_HOST}"
|
||||
echo "master port: ${port}"
|
||||
echo "master port in cmd: ${port_in_cmd}"
|
||||
|
||||
# export WANDB_BASE_URL=https://api.wandb.ai
|
||||
# export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
|
||||
# wandb login $WANDB_API_KEY
|
||||
|
||||
export WANDB_BASE_URL=https://api.wandb.ai
|
||||
export WANDB_PROJECT=vision-reasoning
|
||||
export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
|
||||
export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
|
||||
wandb login $WANDB_API_KEY
|
||||
|
||||
cd /home/tiger/multimodal-open-r1
|
||||
# pip3 install vllm==0.6.6.post1
|
||||
pip3 install -e ".[dev]"
|
||||
pip3 install wandb==0.18.3
|
||||
|
||||
torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
|
||||
--nnodes="${ARNOLD_WORKER_NUM}" \
|
||||
--node_rank="${ARNOLD_ID}" \
|
||||
--master_addr="${METIS_WORKER_0_HOST}" \
|
||||
--master_port="${port_in_cmd}" \
|
||||
src/open_r1/grpo.py \
|
||||
--deepspeed scripts/zero3.json \
|
||||
--output_dir checkpoints/${WANDB_RUN_NAME} \
|
||||
--model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
|
||||
--dataset_name luodian/${DATASET_NAME} \
|
||||
--max_prompt_length 8192 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 1 \
|
||||
--bf16 \
|
||||
--report_to wandb \
|
||||
--gradient_checkpointing true \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--max_pixels 2359296 \
|
||||
--save_total_limit 8 \
|
||||
--num_train_epochs 1 \
|
||||
--run_name $WANDB_RUN_NAME
|
||||
@@ -0,0 +1,41 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "none",
|
||||
"pin_memory": true
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": false,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 100,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "none",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "none",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 100,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 3
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,48 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"steps_per_print": 1e5,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 1e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 1e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": true,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
output_dir: /path/to/output/runs/Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps
|
||||
model_name_or_path: /path/to/models/Qwen2.5-VL-3B-Instruct
|
||||
dataset_name: Idefics-ai2d
|
||||
data_file_paths: /path/to/data/ai2d.jsonl
|
||||
image_folders: /path/to/images
|
||||
max_prompt_length: 1024
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
logging_steps: 1
|
||||
bf16: true
|
||||
report_to: wandb
|
||||
gradient_checkpointing: false
|
||||
deepspeed: /path/to/config/zero3.json
|
||||
attn_implementation: flash_attention_2
|
||||
max_pixels: 401408
|
||||
max_steps: 500
|
||||
run_name: Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps-multinode
|
||||
save_steps: 100
|
||||
save_total_limit: 3
|
||||
save_only_model: true
|
||||
num_generations: 8
|
||||
@@ -0,0 +1,145 @@
|
||||
#!/bin/bash
|
||||
|
||||
RUN_NAME=multinode_training # assume there is a ${RUN_NAME}_args.yaml file in the current directory
|
||||
|
||||
declare -A node2ip_map
|
||||
node2ip_map=(
|
||||
["node1"]="192.168.1.101"
|
||||
["node2"]="192.168.1.102"
|
||||
["node3"]="192.168.1.103"
|
||||
["node4"]="192.168.1.104"
|
||||
)
|
||||
|
||||
# Default nodes if no arguments provided
|
||||
DEFAULT_NODES=("node1" "node2")
|
||||
|
||||
# Local codebase path in file system
|
||||
LOCAL_CODEBASE_PATH="/path/to/your/codebase"
|
||||
|
||||
# Use provided nodes or default nodes
|
||||
if [ "$#" -ge 1 ]; then
|
||||
NODES=("$@")
|
||||
else
|
||||
NODES=("${DEFAULT_NODES[@]}")
|
||||
echo "Using default nodes: ${NODES[*]}"
|
||||
fi
|
||||
|
||||
# Add this debug line
|
||||
echo "All nodes in order: ${NODES[@]}"
|
||||
|
||||
TOTAL_NODES=${#NODES[@]}
|
||||
MASTER_NODE=${NODES[0]}
|
||||
MASTER_PORT=12345
|
||||
|
||||
# Get project root directory (using the directory where this script is located)
|
||||
PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
echo "Project root directory: $PROJECT_ROOT"
|
||||
|
||||
# Get master node IP address
|
||||
echo "MASTER_NODE: $MASTER_NODE"
|
||||
MASTER_IP="${node2ip_map[$MASTER_NODE]}"
|
||||
echo "Master node IP: $MASTER_IP"
|
||||
|
||||
# Create log directory for each node
|
||||
LOG_DIR="path/to/your/log/dir"
|
||||
mkdir -p $LOG_DIR
|
||||
|
||||
# Generate docker-compose.yml
|
||||
echo "Generating docker-compose.yml..."
|
||||
cat > docker-compose.yml << EOL
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
trainer:
|
||||
image: your/training-image:tag
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
shm_size: '8gb'
|
||||
volumes:
|
||||
- /path/to/data:/data
|
||||
- $LOCAL_CODEBASE_PATH/src:/workspace/src
|
||||
environment:
|
||||
- MASTER_ADDR=\${MASTER_ADDR:-$MASTER_IP}
|
||||
- MASTER_PORT=\${MASTER_PORT:-12345}
|
||||
- NODE_RANK=\${NODE_RANK:-0}
|
||||
- WORLD_SIZE=\${WORLD_SIZE:-4}
|
||||
- DEBUG_MODE=true
|
||||
- LOG_PATH=${LOG_DIR}/debug_log.txt
|
||||
- WANDB_API_KEY=your_wandb_api_key # Optional: for logging with weights & biases
|
||||
- WANDB_PROJECT=your_project_name
|
||||
- WANDB_RUN_NAME=${RUN_NAME}-$(date +%Y-%m-%d-%H-%M-%S)
|
||||
- PYTHONPATH=/workspace/src
|
||||
network_mode: "host"
|
||||
command: /bin/bash
|
||||
working_dir: /workspace
|
||||
EOL
|
||||
|
||||
# Function to build training arguments from yaml
|
||||
build_train_args() {
|
||||
args=""
|
||||
while IFS=": " read -r key value; do
|
||||
[[ -z "$key" || "$key" =~ ^[[:space:]]*# ]] && continue
|
||||
value=$(echo "$value" | sed -e 's/^[[:space:]]*//' -e 's/[[:space:]]*$//' -e 's/^"//' -e 's/"$//')
|
||||
if [[ "$value" == "true" ]]; then
|
||||
args="$args --$key"
|
||||
elif [[ "$value" == "false" ]]; then
|
||||
continue
|
||||
else
|
||||
args="$args --$key $value"
|
||||
fi
|
||||
done < ${RUN_NAME}_args.yaml
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
# Get training arguments
|
||||
TRAIN_ARGS=$(build_train_args)
|
||||
echo "TRAIN_ARGS: $TRAIN_ARGS"
|
||||
|
||||
# Launch containers on each node
|
||||
NODE_RANK=0
|
||||
for host in "${NODES[@]}"; do
|
||||
LOG_FILE="$LOG_DIR/${host}_rank${NODE_RANK}.log"
|
||||
if [ "$host" = "$MASTER_NODE" ]; then
|
||||
echo "Launching on master $host with rank $NODE_RANK, logging to $LOG_FILE"
|
||||
ssh $host "cd $PROJECT_ROOT && \
|
||||
MASTER_ADDR=$MASTER_IP \
|
||||
NODE_RANK=$NODE_RANK \
|
||||
WORLD_SIZE=$TOTAL_NODES \
|
||||
sudo -E docker-compose -f docker-compose.yml run --rm trainer \
|
||||
torchrun --nproc_per_node=8 \
|
||||
--nnodes=$TOTAL_NODES \
|
||||
--node_rank=$NODE_RANK \
|
||||
--master_addr=$MASTER_IP \
|
||||
--master_port=$MASTER_PORT \
|
||||
src/train.py \
|
||||
$TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
|
||||
else
|
||||
echo "Launching on $host with rank $NODE_RANK, logging to $LOG_FILE"
|
||||
ssh $host "cd $PROJECT_ROOT && \
|
||||
MASTER_ADDR=$MASTER_IP \
|
||||
NODE_RANK=$NODE_RANK \
|
||||
WORLD_SIZE=$TOTAL_NODES \
|
||||
sudo -E docker-compose -f docker-compose.yml run --rm trainer \
|
||||
torchrun --nproc_per_node=8 \
|
||||
--nnodes=$TOTAL_NODES \
|
||||
--node_rank=$NODE_RANK \
|
||||
--master_addr=$MASTER_IP \
|
||||
--master_port=$MASTER_PORT \
|
||||
src/train.py \
|
||||
$TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
|
||||
fi
|
||||
|
||||
NODE_RANK=$((NODE_RANK + 1))
|
||||
done
|
||||
|
||||
echo "Jobs launched. To monitor the logs, you can:"
|
||||
echo "1. Use 'tail -f $LOG_DIR/*.log' to watch all logs"
|
||||
echo "2. Use 'tail -f $LOG_DIR/<node_name>_rank<N>.log' to watch a specific node"
|
||||
|
||||
# Wait for all background processes to complete
|
||||
wait
|
||||
@@ -0,0 +1,36 @@
|
||||
cd src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true"
|
||||
# export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
|
||||
RUN_NAME="Qwen2.5-VL-3B-GRPO-GUI_multi-image"
|
||||
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
||||
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12346" \
|
||||
src/open_r1/grpo_jsonl.py \
|
||||
--deepspeed local_scripts/zero3.json \
|
||||
--output_dir output/$RUN_NAME \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--dataset_name none \
|
||||
--image_folders /path/to/images/ \
|
||||
--data_file_paths data_jsonl/gui_multi-image.jsonl \
|
||||
--freeze_vision_modules true \
|
||||
--max_prompt_length 1024 \
|
||||
--num_generations 8 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--logging_steps 1 \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--data_seed 42 \
|
||||
--report_to wandb \
|
||||
--gradient_checkpointing true \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--num_train_epochs 2 \
|
||||
--run_name $RUN_NAME \
|
||||
--save_steps 100 \
|
||||
--save_only_model true
|
||||
@@ -0,0 +1,34 @@
|
||||
cd src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true"
|
||||
# export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
|
||||
RUN_NAME="Qwen2.5-VL-3B-GRPO-REC"
|
||||
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
||||
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12346" \
|
||||
src/open_r1/grpo_rec.py \
|
||||
--deepspeed local_scripts/zero3.json \
|
||||
--output_dir output/$RUN_NAME \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--dataset_name data_config/rec.yaml \
|
||||
--image_root <your_image_root> \
|
||||
--max_prompt_length 1024 \
|
||||
--num_generations 8 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--logging_steps 1 \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--data_seed 42 \
|
||||
--report_to wandb \
|
||||
--gradient_checkpointing false \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--num_train_epochs 2 \
|
||||
--run_name $RUN_NAME \
|
||||
--save_steps 100 \
|
||||
--save_only_model true
|
||||
@@ -0,0 +1,36 @@
|
||||
cd src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true"
|
||||
# export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
|
||||
RUN_NAME="InternVL-4B-GRPO-REC"
|
||||
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
||||
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12346" \
|
||||
src/open_r1/grpo_rec.py \
|
||||
--deepspeed local_scripts/zero_stage2_config.json \
|
||||
--output_dir output/$RUN_NAME \
|
||||
--model_name_or_path /data10/shz/ckpt/vlm-r1-related/InternVL2_5-4B \
|
||||
--dataset_name data_config/rec_internvl.yaml \
|
||||
--image_root /data10/shz/dataset/coco \
|
||||
--freeze_vision_modules true \
|
||||
--max_anyres_num 6 \
|
||||
--max_prompt_length 1024 \
|
||||
--num_generations 8 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--logging_steps 1 \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--data_seed 42 \
|
||||
--report_to wandb \
|
||||
--gradient_checkpointing true \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--num_train_epochs 2 \
|
||||
--run_name $RUN_NAME \
|
||||
--save_steps 100 \
|
||||
--save_only_model true
|
||||
@@ -0,0 +1,43 @@
|
||||
cd src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true"
|
||||
# export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
|
||||
RUN_NAME="Qwen2.5-VL-7B-GRPO-REC-lora"
|
||||
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
||||
|
||||
torchrun --nproc_per_node="8" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12346" \
|
||||
src/open_r1/grpo_rec.py \
|
||||
--deepspeed local_scripts/zero2.json \
|
||||
--output_dir output/$RUN_NAME \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dataset_name data_config/rec.yaml \
|
||||
--image_root <your_image_root> \
|
||||
--max_prompt_length 1024 \
|
||||
--num_generations 8 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--logging_steps 1 \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--data_seed 42 \
|
||||
--report_to wandb \
|
||||
--gradient_checkpointing true \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--num_train_epochs 2 \
|
||||
--run_name $RUN_NAME \
|
||||
--save_steps 100 \
|
||||
--save_only_model true \
|
||||
--learning_rate 1e-5 \
|
||||
--use_peft true \
|
||||
--lora_r 64 \
|
||||
--lora_alpha 128 \
|
||||
--lora_dropout 0.05 \
|
||||
--lora_task_type CAUSAL_LM \
|
||||
--freeze_vision_modules true
|
||||
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
cd /workspace/VLM-R1/src/open-r1-multimodal
|
||||
|
||||
export DEBUG_MODE="true"
|
||||
export CUDA_VISIBLE_DEVICES=1,2
|
||||
|
||||
RUN_NAME="Qwen2.5-VL-3B-GRPO-REC-SFT"
|
||||
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
||||
|
||||
torchrun --nproc_per_node="2" \
|
||||
--nnodes="1" \
|
||||
--node_rank="0" \
|
||||
--master_addr="127.0.0.1" \
|
||||
--master_port="12346" \
|
||||
src/open_r1/grpo_rec.py \
|
||||
--deepspeed local_scripts/zero2.json \
|
||||
--output_dir output/$RUN_NAME \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--dataset_name data_config/rec.yaml \
|
||||
--image_root /data \
|
||||
--max_prompt_length 1024 \
|
||||
--num_generations 2 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--logging_steps 1 \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--data_seed 42 \
|
||||
--report_to wandb \
|
||||
--gradient_checkpointing false \
|
||||
--attn_implementation flash_attention_2 \
|
||||
--num_train_epochs 2 \
|
||||
--run_name $RUN_NAME \
|
||||
--save_steps 100 \
|
||||
--beta 0.0 \
|
||||
--epsilon_high 0.28
|
||||
41
post-training/VLM-R1/src/open-r1-multimodal/setup.cfg
Normal file
@@ -0,0 +1,41 @@
|
||||
[isort]
|
||||
default_section = FIRSTPARTY
|
||||
ensure_newline_before_comments = True
|
||||
force_grid_wrap = 0
|
||||
include_trailing_comma = True
|
||||
known_first_party = open_r1
|
||||
known_third_party =
|
||||
transformers
|
||||
datasets
|
||||
fugashi
|
||||
git
|
||||
h5py
|
||||
matplotlib
|
||||
nltk
|
||||
numpy
|
||||
packaging
|
||||
pandas
|
||||
psutil
|
||||
pytest
|
||||
rouge_score
|
||||
sacrebleu
|
||||
seqeval
|
||||
sklearn
|
||||
streamlit
|
||||
torch
|
||||
tqdm
|
||||
|
||||
line_length = 119
|
||||
lines_after_imports = 2
|
||||
multi_line_output = 3
|
||||
use_parentheses = True
|
||||
|
||||
[flake8]
|
||||
ignore = E203, E501, E741, W503, W605
|
||||
max-line-length = 119
|
||||
per-file-ignores =
|
||||
# imported but unused
|
||||
__init__.py: F401
|
||||
|
||||
[tool:pytest]
|
||||
doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
|
||||
137
post-training/VLM-R1/src/open-r1-multimodal/setup.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
|
||||
|
||||
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
|
||||
stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
|
||||
if stale_egg_info.exists():
|
||||
print(
|
||||
(
|
||||
"Warning: {} exists.\n\n"
|
||||
"If you recently updated open_r1, this is expected,\n"
|
||||
"but it may prevent open_r1 from installing in editable mode.\n\n"
|
||||
"This directory is automatically generated by Python's packaging tools.\n"
|
||||
"I will remove it now.\n\n"
|
||||
"See https://github.com/pypa/pip/issues/5466 for details.\n"
|
||||
).format(stale_egg_info)
|
||||
)
|
||||
shutil.rmtree(stale_egg_info)
|
||||
|
||||
|
||||
# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
|
||||
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
|
||||
_deps = [
|
||||
"accelerate>=1.2.1",
|
||||
"bitsandbytes>=0.43.0",
|
||||
"black>=24.4.2",
|
||||
"datasets>=3.2.0",
|
||||
"deepspeed==0.15.4",
|
||||
"distilabel[vllm,ray,openai]>=1.5.2",
|
||||
"einops>=0.8.0",
|
||||
"flake8>=6.0.0",
|
||||
"hf_transfer>=0.1.4",
|
||||
"huggingface-hub[cli]>=0.19.2,<1.0",
|
||||
"isort>=5.12.0",
|
||||
"liger_kernel==0.5.2",
|
||||
# "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
|
||||
"math-verify", # Used for math verification in grpo
|
||||
"packaging>=23.0",
|
||||
"parameterized>=0.9.0",
|
||||
"pytest",
|
||||
"safetensors>=0.3.3",
|
||||
"sentencepiece>=0.1.99",
|
||||
"torch>=2.5.1",
|
||||
"transformers==4.49.0",
|
||||
"trl @ git+https://github.com/huggingface/trl.git@main",
|
||||
"vllm==0.6.6.post1",
|
||||
"wandb>=0.19.1",
|
||||
"pillow",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
#
|
||||
# tokenizers: "tokenizers==0.9.4"
|
||||
# packaging: "packaging"
|
||||
#
|
||||
# some of the values are versioned whereas others aren't.
|
||||
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
|
||||
|
||||
|
||||
def deps_list(*pkgs):
|
||||
return [deps[pkg] for pkg in pkgs]
|
||||
|
||||
|
||||
extras = {}
|
||||
extras["tests"] = deps_list("pytest", "parameterized")
|
||||
extras["torch"] = deps_list("torch")
|
||||
extras["quality"] = deps_list("black", "isort", "flake8")
|
||||
# extras["eval"] = deps_list("lighteval", "math-verify")
|
||||
extras["eval"] = deps_list("math-verify")
|
||||
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
|
||||
|
||||
# core dependencies shared across the whole project - keep this to a bare minimum :)
|
||||
install_requires = [
|
||||
deps["accelerate"],
|
||||
deps["bitsandbytes"],
|
||||
deps["einops"],
|
||||
deps["datasets"],
|
||||
deps["deepspeed"],
|
||||
deps["hf_transfer"],
|
||||
deps["huggingface-hub"],
|
||||
deps["liger_kernel"],
|
||||
deps["packaging"], # utilities from PyPA to e.g., compare versions
|
||||
deps["safetensors"],
|
||||
deps["sentencepiece"],
|
||||
deps["transformers"],
|
||||
deps["trl"],
|
||||
]
|
||||
|
||||
setup(
|
||||
name="open-r1",
|
||||
version="0.1.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future)",
|
||||
author_email="lewis@huggingface.co",
|
||||
description="Open R1",
|
||||
# long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords="llm inference-time compute reasoning",
|
||||
license="Apache",
|
||||
url="https://github.com/huggingface/open-r1",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
zip_safe=False,
|
||||
extras_require=extras,
|
||||
python_requires=">=3.10.9",
|
||||
install_requires=install_requires,
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,82 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import trl
|
||||
|
||||
|
||||
# TODO: add the shared options with a mixin to reduce code duplication
|
||||
@dataclass
|
||||
class GRPOConfig(trl.GRPOConfig):
|
||||
"""
|
||||
args for callbacks, benchmarks etc
|
||||
"""
|
||||
|
||||
benchmarks: list[str] = field(
|
||||
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
|
||||
)
|
||||
callbacks: list[str] = field(
|
||||
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
|
||||
)
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
|
||||
)
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main", metadata={"help": "The Hub model branch to push the model to."}
|
||||
)
|
||||
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
|
||||
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
|
||||
wandb_entity: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The entity to store runs under.")},
|
||||
)
|
||||
wandb_project: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The project to store runs under.")},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTConfig(trl.SFTConfig):
|
||||
"""
|
||||
args for callbacks, benchmarks etc
|
||||
"""
|
||||
|
||||
benchmarks: list[str] = field(
|
||||
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
|
||||
)
|
||||
callbacks: list[str] = field(
|
||||
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
|
||||
)
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The optional system prompt to use for benchmarking."},
|
||||
)
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": "The Hub model branch to push the model to."},
|
||||
)
|
||||
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
|
||||
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
|
||||
wandb_entity: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The entity to store runs under.")},
|
||||
)
|
||||
wandb_project: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The project to store runs under.")},
|
||||
)
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Custom evaluation tasks for LightEval."""
|
||||
|
||||
from lighteval.metrics.dynamic_metrics import (
|
||||
ExprExtractionConfig,
|
||||
LatexExtractionConfig,
|
||||
multilingual_extractive_match_metric,
|
||||
)
|
||||
from lighteval.tasks.lighteval_task import LightevalTaskConfig
|
||||
from lighteval.tasks.requests import Doc
|
||||
from lighteval.utils.language import Language
|
||||
|
||||
|
||||
metric = multilingual_extractive_match_metric(
|
||||
language=Language.ENGLISH,
|
||||
fallback_mode="first_match",
|
||||
precision=5,
|
||||
gold_extraction_target=(LatexExtractionConfig(),),
|
||||
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
|
||||
aggregation_function=max,
|
||||
)
|
||||
|
||||
|
||||
def prompt_fn(line, task_name: str = None):
|
||||
"""Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
|
||||
return Doc(
|
||||
task_name=task_name,
|
||||
query=line["problem"],
|
||||
choices=[line["solution"]],
|
||||
gold_index=0,
|
||||
)
|
||||
|
||||
|
||||
# Define tasks
|
||||
aime24 = LightevalTaskConfig(
|
||||
name="aime24",
|
||||
suite=["custom"],
|
||||
prompt_function=prompt_fn,
|
||||
hf_repo="HuggingFaceH4/aime_2024",
|
||||
hf_subset="default",
|
||||
hf_avail_splits=["train"],
|
||||
evaluation_splits=["train"],
|
||||
few_shots_split=None,
|
||||
few_shots_select=None,
|
||||
generation_size=32768,
|
||||
metric=[metric],
|
||||
version=1,
|
||||
)
|
||||
math_500 = LightevalTaskConfig(
|
||||
name="math_500",
|
||||
suite=["custom"],
|
||||
prompt_function=prompt_fn,
|
||||
hf_repo="HuggingFaceH4/MATH-500",
|
||||
hf_subset="default",
|
||||
hf_avail_splits=["test"],
|
||||
evaluation_splits=["test"],
|
||||
few_shots_split=None,
|
||||
few_shots_select=None,
|
||||
generation_size=32768,
|
||||
metric=[metric],
|
||||
version=1,
|
||||
)
|
||||
|
||||
# Add tasks to the table
|
||||
TASKS_TABLE = []
|
||||
TASKS_TABLE.append(aime24)
|
||||
TASKS_TABLE.append(math_500)
|
||||
|
||||
# MODULE LOGIC
|
||||
if __name__ == "__main__":
|
||||
print([t["name"] for t in TASKS_TABLE])
|
||||
print(len(TASKS_TABLE))
|
||||
@@ -0,0 +1,156 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from distilabel.llms import OpenAILLM
|
||||
from distilabel.pipeline import Pipeline
|
||||
from distilabel.steps.tasks import TextGeneration
|
||||
|
||||
|
||||
def build_distilabel_pipeline(
|
||||
model: str,
|
||||
base_url: str = "http://localhost:8000/v1",
|
||||
prompt_column: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
max_new_tokens: int = 8192,
|
||||
num_generations: int = 1,
|
||||
) -> Pipeline:
|
||||
generation_kwargs = {"max_new_tokens": max_new_tokens}
|
||||
|
||||
if temperature is not None:
|
||||
generation_kwargs["temperature"] = temperature
|
||||
|
||||
if top_p is not None:
|
||||
generation_kwargs["top_p"] = top_p
|
||||
|
||||
with Pipeline().ray() as pipeline:
|
||||
TextGeneration(
|
||||
llm=OpenAILLM(
|
||||
base_url=base_url,
|
||||
api_key="something",
|
||||
model=model,
|
||||
# thinking can take some time...
|
||||
timeout=10 * 60,
|
||||
generation_kwargs=generation_kwargs,
|
||||
),
|
||||
input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
|
||||
input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
|
||||
num_generations=num_generations,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
|
||||
parser.add_argument(
|
||||
"--hf-dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace dataset to load",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-dataset-config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Dataset config to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-dataset-split",
|
||||
type=str,
|
||||
default="train",
|
||||
help="Dataset split to use",
|
||||
)
|
||||
parser.add_argument("--prompt-column", type=str, default="prompt")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model name to use for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-server-url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="URL of the vLLM server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
help="Temperature for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
help="Top-p value for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-new-tokens",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="Maximum number of new tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-generations",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generations per problem",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-output-dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
help="HuggingFace repo to push results to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="Whether to make the output dataset private when pushing to HF Hub",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("\nRunning with arguments:")
|
||||
for arg, value in vars(args).items():
|
||||
print(f" {arg}: {value}")
|
||||
print()
|
||||
|
||||
print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
|
||||
dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
|
||||
print("Dataset loaded!")
|
||||
|
||||
pipeline = build_distilabel_pipeline(
|
||||
model=args.model,
|
||||
base_url=args.vllm_server_url,
|
||||
prompt_column=args.prompt_column,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
num_generations=args.num_generations,
|
||||
)
|
||||
|
||||
print("Running generation pipeline...")
|
||||
distiset = pipeline.run(dataset=dataset, use_cache=False)
|
||||
print("Generation pipeline finished!")
|
||||
|
||||
if args.hf_output_dataset:
|
||||
print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
|
||||
distiset.push_to_hub(args.hf_output_dataset, private=args.private)
|
||||
print("Dataset pushed!")
|
||||
214
post-training/VLM-R1/src/open-r1-multimodal/src/open_r1/grpo.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# import debugpy
|
||||
# try:
|
||||
# # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
|
||||
# debugpy.listen(("localhost", 9501))
|
||||
# print("Waiting for debugger attach")
|
||||
# debugpy.wait_for_client()
|
||||
# except Exception as e:
|
||||
# pass
|
||||
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from transformers import Qwen2VLForConditionalGeneration
|
||||
|
||||
from math_verify import parse, verify
|
||||
from open_r1.trainer import VLMGRPOTrainer
|
||||
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRPOScriptArguments(ScriptArguments):
|
||||
"""
|
||||
Script arguments for the GRPO training script.
|
||||
|
||||
Args:
|
||||
reward_funcs (`list[str]`):
|
||||
List of reward functions. Possible values: 'accuracy', 'format'.
|
||||
"""
|
||||
|
||||
reward_funcs: list[str] = field(
|
||||
default_factory=lambda: ["accuracy", "format"],
|
||||
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
||||
)
|
||||
max_pixels: Optional[int] = field(
|
||||
default=12845056,
|
||||
metadata={"help": "Maximum number of pixels for the image"},
|
||||
)
|
||||
min_pixels: Optional[int] = field(
|
||||
default=3136,
|
||||
metadata={"help": "Minimum number of pixels for the image"},
|
||||
)
|
||||
|
||||
|
||||
def accuracy_reward(completions, solution, **kwargs):
|
||||
"""Reward function that checks if the completion is correct using either symbolic verification or exact string matching."""
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
rewards = []
|
||||
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
||||
for content, sol in zip(contents, solution):
|
||||
reward = 0.0
|
||||
# Try symbolic verification first
|
||||
try:
|
||||
answer = parse(content)
|
||||
if float(verify(answer, parse(sol))) > 0:
|
||||
reward = 1.0
|
||||
except Exception:
|
||||
pass # Continue to next verification method if this fails
|
||||
|
||||
# If symbolic verification failed, try string matching
|
||||
if reward == 0.0:
|
||||
try:
|
||||
# Extract answer from solution if it has think/answer tags
|
||||
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
|
||||
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
|
||||
|
||||
# Extract answer from content if it has think/answer tags
|
||||
content_match = re.search(r'<answer>(.*?)</answer>', content)
|
||||
student_answer = content_match.group(1).strip() if content_match else content.strip()
|
||||
|
||||
# Compare the extracted answers
|
||||
if student_answer == ground_truth:
|
||||
reward = 1.0
|
||||
except Exception:
|
||||
pass # Keep reward as 0.0 if both methods fail
|
||||
|
||||
rewards.append(reward)
|
||||
if os.getenv("DEBUG_MODE") == "true":
|
||||
log_path = os.getenv("LOG_PATH")
|
||||
# local_rank = int(os.getenv("LOCAL_RANK", 0))
|
||||
with open(log_path, "a") as f:
|
||||
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
|
||||
f.write(f"Content: {content}\n")
|
||||
f.write(f"Solution: {sol}\n")
|
||||
return rewards
|
||||
|
||||
|
||||
def format_reward(completions, **kwargs):
|
||||
"""Reward function that checks if the completion has a specific format."""
|
||||
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
|
||||
completion_contents = [completion[0]["content"] for completion in completions]
|
||||
matches = [re.match(pattern, content) for content in completion_contents]
|
||||
return [1.0 if match else 0.0 for match in matches]
|
||||
|
||||
|
||||
reward_funcs_registry = {
|
||||
"accuracy": accuracy_reward,
|
||||
"format": format_reward,
|
||||
}
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
||||
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
||||
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
||||
"<think> reasoning process here </think><answer> answer here </answer>"
|
||||
)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Get reward functions
|
||||
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
||||
print("reward_funcs:", reward_funcs)
|
||||
|
||||
# Load the dataset
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
|
||||
# Format into conversation
|
||||
def make_conversation(example):
|
||||
return {
|
||||
"prompt": [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": example["problem"]},
|
||||
],
|
||||
}
|
||||
|
||||
# def make_conversation_image(example):
|
||||
# return {
|
||||
# "prompt": [
|
||||
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "image"},
|
||||
# {"type": "text", "text": example["problem"]},
|
||||
# ],
|
||||
# },
|
||||
# ],
|
||||
# }
|
||||
|
||||
QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
|
||||
|
||||
def make_conversation_image(example):
|
||||
return {
|
||||
"prompt": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
if "image" in dataset[script_args.dataset_train_split].features:
|
||||
print("has image in dataset")
|
||||
dataset = dataset.map(make_conversation_image) # Utilize multiprocessing for faster mapping
|
||||
# dataset = dataset.remove_columns(["original_question", "original_answer"])
|
||||
|
||||
else:
|
||||
print("no image in dataset")
|
||||
dataset = dataset.map(make_conversation)
|
||||
dataset = dataset.remove_columns("messages")
|
||||
|
||||
|
||||
trainer_cls = VLMGRPOTrainer
|
||||
|
||||
|
||||
# Initialize the GRPO trainer
|
||||
trainer = trainer_cls(
|
||||
model=model_args.model_name_or_path,
|
||||
reward_funcs=reward_funcs,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
peft_config=get_peft_config(model_args),
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
max_pixels=script_args.max_pixels,
|
||||
min_pixels=script_args.min_pixels,
|
||||
torch_dtype=model_args.torch_dtype,
|
||||
)
|
||||
|
||||
# Train and push the model to the Hub
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
main(script_args, training_args, model_args)
|
||||
@@ -0,0 +1,254 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# import debugpy
|
||||
# try:
|
||||
# # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
|
||||
# debugpy.listen(("localhost", 9501))
|
||||
# print("Waiting for debugger attach")
|
||||
# debugpy.wait_for_client()
|
||||
# except Exception as e:
|
||||
# pass
|
||||
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import Qwen2VLForConditionalGeneration
|
||||
|
||||
from math_verify import parse, verify
|
||||
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
|
||||
from open_r1.vlm_modules import *
|
||||
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
||||
from transformers import TrainingArguments
|
||||
import yaml
|
||||
import json
|
||||
import random
|
||||
import math
|
||||
|
||||
from open_r1.qwen2_5vl_monkey_patch import monkey_patch_qwen2_5vl_flash_attn, monkey_patch_qwen2_5vl_forward
|
||||
monkey_patch_qwen2_5vl_flash_attn()
|
||||
|
||||
|
||||
# ----------------------- Main Script -----------------------
|
||||
@dataclass
|
||||
class GRPOScriptArguments(ScriptArguments):
|
||||
"""
|
||||
Script arguments for the GRPO training script.
|
||||
|
||||
Args:
|
||||
reward_funcs (`list[str]`):
|
||||
List of reward functions. Possible values: 'accuracy', 'format'.
|
||||
"""
|
||||
|
||||
reward_funcs: list[str] = field(
|
||||
default_factory=lambda: ["accuracy", "format"],
|
||||
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
||||
)
|
||||
max_pixels: Optional[int] = field(
|
||||
default=12845056,
|
||||
metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
|
||||
)
|
||||
min_pixels: Optional[int] = field(
|
||||
default=3136,
|
||||
metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
|
||||
)
|
||||
max_anyres_num: Optional[int] = field(
|
||||
default=12,
|
||||
metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
|
||||
)
|
||||
image_root: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Root directory of the image"},
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class GRPOModelConfig(ModelConfig):
|
||||
freeze_vision_modules: bool = False
|
||||
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
||||
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
||||
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
||||
"<think> reasoning process here </think><answer> answer here </answer>"
|
||||
)
|
||||
|
||||
class LazySupervisedDataset(Dataset):
|
||||
def __init__(self, data_path: str, script_args: GRPOScriptArguments, question_template: str):
|
||||
super(LazySupervisedDataset, self).__init__()
|
||||
self.script_args = script_args
|
||||
self.list_data_dict = []
|
||||
self.question_template = question_template
|
||||
|
||||
if data_path.endswith(".yaml"):
|
||||
with open(data_path, "r") as file:
|
||||
yaml_data = yaml.safe_load(file)
|
||||
datasets = yaml_data.get("datasets")
|
||||
# file should be in the format of:
|
||||
# datasets:
|
||||
# - json_path: xxxx1.json
|
||||
# sampling_strategy: first:1000
|
||||
# - json_path: xxxx2.json
|
||||
# sampling_strategy: end:3000
|
||||
# - json_path: xxxx3.json
|
||||
# sampling_strategy: random:999
|
||||
|
||||
for data in datasets:
|
||||
json_path = data.get("json_path")
|
||||
sampling_strategy = data.get("sampling_strategy", "all")
|
||||
sampling_number = None
|
||||
|
||||
if json_path.endswith(".jsonl"):
|
||||
cur_data_dict = []
|
||||
with open(json_path, "r") as json_file:
|
||||
for line in json_file:
|
||||
cur_data_dict.append(json.loads(line.strip()))
|
||||
elif json_path.endswith(".json"):
|
||||
with open(json_path, "r") as json_file:
|
||||
cur_data_dict = json.load(json_file)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {json_path}")
|
||||
|
||||
if ":" in sampling_strategy:
|
||||
sampling_strategy, sampling_number = sampling_strategy.split(":")
|
||||
if "%" in sampling_number:
|
||||
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
|
||||
else:
|
||||
sampling_number = int(sampling_number)
|
||||
|
||||
# Apply the sampling strategy
|
||||
if sampling_strategy == "first" and sampling_number is not None:
|
||||
cur_data_dict = cur_data_dict[:sampling_number]
|
||||
elif sampling_strategy == "end" and sampling_number is not None:
|
||||
cur_data_dict = cur_data_dict[-sampling_number:]
|
||||
elif sampling_strategy == "random" and sampling_number is not None:
|
||||
random.shuffle(cur_data_dict)
|
||||
cur_data_dict = cur_data_dict[:sampling_number]
|
||||
print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
|
||||
self.list_data_dict.extend(cur_data_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {data_path}")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.list_data_dict)
|
||||
|
||||
def __getitem__(self, i):
|
||||
# Format into conversation
|
||||
def make_conversation(example):
|
||||
return {
|
||||
"prompt": [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": example["problem"]},
|
||||
],
|
||||
}
|
||||
QUESTION_TEMPLATE = self.question_template
|
||||
def make_conversation_image(example):
|
||||
return {
|
||||
"prompt": [
|
||||
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
example = self.list_data_dict[i]
|
||||
image_root = self.script_args.image_root
|
||||
if 'image' in example:
|
||||
image_path = os.path.join(image_root, example['image'])
|
||||
# In case the image is not found
|
||||
while not os.path.exists(image_path):
|
||||
print(f"Warning: Image {image_path} not found, randomly selecting another image")
|
||||
new_index = random.randint(0, len(self.list_data_dict)-1)
|
||||
example = self.list_data_dict[new_index]
|
||||
image_path = os.path.join(image_root, example['image'])
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
else:
|
||||
image = None
|
||||
|
||||
|
||||
return {
|
||||
'image': image,
|
||||
'problem': example['problem'],
|
||||
'solution': example['solution'],
|
||||
'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt'],
|
||||
}
|
||||
|
||||
|
||||
def get_vlm_module(model_name_or_path):
|
||||
if "qwen" in model_name_or_path.lower():
|
||||
return Qwen2VLModule
|
||||
elif "internvl" in model_name_or_path.lower():
|
||||
return InvernVLModule
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model_name_or_path}")
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Load the VLM module
|
||||
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
|
||||
|
||||
# Load the reward functions
|
||||
reward_funcs_registry = {
|
||||
"accuracy": vlm_module_cls.iou_reward,
|
||||
"format": vlm_module_cls.format_reward_rec,
|
||||
}
|
||||
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
||||
print("reward_funcs:", reward_funcs)
|
||||
|
||||
# Load the dataset
|
||||
dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
|
||||
|
||||
trainer_cls = VLMGRPOTrainer
|
||||
# Initialize the GRPO trainer
|
||||
trainer = trainer_cls(
|
||||
model=model_args.model_name_or_path,
|
||||
reward_funcs=reward_funcs,
|
||||
args=training_args,
|
||||
vlm_module=vlm_module_cls(),
|
||||
train_dataset=dataset,
|
||||
eval_dataset=None,
|
||||
peft_config=get_peft_config(model_args),
|
||||
freeze_vision_modules=model_args.freeze_vision_modules,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
max_pixels=script_args.max_pixels,
|
||||
min_pixels=script_args.min_pixels,
|
||||
max_anyres_num=script_args.max_anyres_num,
|
||||
torch_dtype=model_args.torch_dtype,
|
||||
)
|
||||
|
||||
# Train and push the model to the Hub
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
if training_args.deepspeed and "zero3" in training_args.deepspeed:
|
||||
print("zero3 is used, qwen2_5vl forward monkey patch is applied")
|
||||
monkey_patch_qwen2_5vl_forward()
|
||||
main(script_args, training_args, model_args)
|
||||
@@ -0,0 +1,229 @@
|
||||
|
||||
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
|
||||
import torch
|
||||
from typing import Tuple, Optional
|
||||
def qwen2_5vl_vision_flash_attn_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
# print(111, 222, 333, 444, 555, 666, 777, 888, 999)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
cos = emb.cos().float()
|
||||
sin = emb.sin().float()
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
# Add this
|
||||
cos = cos.to(torch.float)
|
||||
sin = sin.to(torch.float)
|
||||
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
|
||||
q = q.squeeze(0)
|
||||
k = k.squeeze(0)
|
||||
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
||||
seq_length, -1
|
||||
)
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def monkey_patch_qwen2_5vl_flash_attn():
|
||||
Qwen2_5_VLVisionFlashAttention2.forward = qwen2_5vl_vision_flash_attn_forward
|
||||
|
||||
|
||||
# ----------------------- Fix the process pending bug when using data mixture of image-text data and pure-text under deepseed zero3-----------------------
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
|
||||
from typing import List, Union
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
def qwen2_5vl_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
|
||||
has_images_global = False
|
||||
if pixel_values is not None:
|
||||
has_images_local = torch.tensor(1, device=input_ids.device)
|
||||
else:
|
||||
has_images_local = torch.tensor(0, device=input_ids.device)
|
||||
# Use all_reduce to ensure all GPUs know if there are images to process
|
||||
torch.distributed.all_reduce(has_images_local, op=torch.distributed.ReduceOp.MAX)
|
||||
has_images_global = has_images_local.item() > 0
|
||||
|
||||
# If there are image inputs globally, ensure all GPUs call the visual model
|
||||
if has_images_global:
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == self.config.image_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# Create a dummy image data for triggering parameter synchronization
|
||||
dummy_pixel_values = torch.zeros((4, 1176), device=input_ids.device, dtype=self.visual.dtype)
|
||||
dummy_grid_thw = torch.tensor([[1, 2, 2]], device=input_ids.device)
|
||||
_ = self.visual(dummy_pixel_values, grid_thw=dummy_grid_thw)
|
||||
|
||||
# Currently, video processing is not handled.
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == self.config.video_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||
):
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts,
|
||||
attention_mask,
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = (
|
||||
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||
if cache_position is not None
|
||||
else 0
|
||||
)
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
||||
position_ids = position_ids.add(delta)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=None,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Qwen2_5_VLCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
rope_deltas=self.rope_deltas,
|
||||
)
|
||||
|
||||
def monkey_patch_qwen2_5vl_forward():
|
||||
Qwen2_5_VLForConditionalGeneration.forward = qwen2_5vl_forward
|
||||
|
||||
# ----------------------- Set the Weights only as False in torch.load (In Pytorch 2.6, this is default as True)-----------------------
|
||||
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
|
||||
from deepspeed.utils import logger, log_dist
|
||||
def weigths_only_load(self, path: str, map_location=None):
|
||||
logger.info(f"[Torch] Loading checkpoint from {path}...")
|
||||
partition = torch.load(path, map_location=map_location, weights_only=False)
|
||||
logger.info(f"[Torch] Loaded checkpoint from {path}.")
|
||||
return partition
|
||||
|
||||
def monkey_patch_torch_load():
|
||||
TorchCheckpointEngine.load = weigths_only_load
|
||||
|
||||
|
||||
|
||||
346
post-training/VLM-R1/src/open-r1-multimodal/src/open_r1/sft.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Supervised fine-tuning script for decoder language models.
|
||||
|
||||
Usage:
|
||||
|
||||
# One 1 node of 8 x H100s
|
||||
accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
|
||||
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--max_seq_length 4096 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--gradient_checkpointing \
|
||||
--bf16 \
|
||||
--logging_steps 5 \
|
||||
--eval_strategy steps \
|
||||
--eval_steps 100 \
|
||||
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, set_seed, AutoProcessor
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from open_r1.configs import SFTConfig
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
import yaml
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
from PIL import Image
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from dataclasses import field
|
||||
from qwen_vl_utils import process_vision_info
|
||||
logger = logging.getLogger(__name__)
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class SFTScriptArguments(ScriptArguments):
|
||||
image_root: str = field(default=None, metadata={"help": "The root directory of the image."})
|
||||
|
||||
|
||||
processor = None
|
||||
|
||||
class LazySupervisedDataset(Dataset):
|
||||
def __init__(self, data_path: str, script_args: ScriptArguments):
|
||||
super(LazySupervisedDataset, self).__init__()
|
||||
self.script_args = script_args
|
||||
self.list_data_dict = []
|
||||
|
||||
if data_path.endswith(".yaml"):
|
||||
with open(data_path, "r") as file:
|
||||
yaml_data = yaml.safe_load(file)
|
||||
datasets = yaml_data.get("datasets")
|
||||
# file should be in the format of:
|
||||
# datasets:
|
||||
# - json_path: xxxx1.json
|
||||
# sampling_strategy: first:1000
|
||||
# - json_path: xxxx2.json
|
||||
# sampling_strategy: end:3000
|
||||
# - json_path: xxxx3.json
|
||||
# sampling_strategy: random:999
|
||||
|
||||
for data in datasets:
|
||||
json_path = data.get("json_path")
|
||||
sampling_strategy = data.get("sampling_strategy", "all")
|
||||
sampling_number = None
|
||||
|
||||
if json_path.endswith(".jsonl"):
|
||||
cur_data_dict = []
|
||||
with open(json_path, "r") as json_file:
|
||||
for line in json_file:
|
||||
cur_data_dict.append(json.loads(line.strip()))
|
||||
elif json_path.endswith(".json"):
|
||||
with open(json_path, "r") as json_file:
|
||||
cur_data_dict = json.load(json_file)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {json_path}")
|
||||
|
||||
if ":" in sampling_strategy:
|
||||
sampling_strategy, sampling_number = sampling_strategy.split(":")
|
||||
if "%" in sampling_number:
|
||||
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
|
||||
else:
|
||||
sampling_number = int(sampling_number)
|
||||
|
||||
# Apply the sampling strategy
|
||||
if sampling_strategy == "first" and sampling_number is not None:
|
||||
cur_data_dict = cur_data_dict[:sampling_number]
|
||||
elif sampling_strategy == "end" and sampling_number is not None:
|
||||
cur_data_dict = cur_data_dict[-sampling_number:]
|
||||
elif sampling_strategy == "random" and sampling_number is not None:
|
||||
random.shuffle(cur_data_dict)
|
||||
cur_data_dict = cur_data_dict[:sampling_number]
|
||||
print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
|
||||
self.list_data_dict.extend(cur_data_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {data_path}")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.list_data_dict)
|
||||
|
||||
def __getitem__(self, i):
|
||||
# Format into conversation
|
||||
def make_conversation_image(example):
|
||||
image_root = self.script_args.image_root
|
||||
# print(111, image_root)
|
||||
# print(222, example['image'])
|
||||
image_path = os.path.join(image_root, example['image'])
|
||||
x1, y1, x2, y2 = example["solution"]
|
||||
normal_caption = example["normal_caption"]
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": f"file://{image_path}"},
|
||||
{"type": "text", "text": example["problem"]},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f'```json\n[\n\t{{"bbox_2d": [{int(x1)}, {int(y1)}, {int(x2)}, {int(y2)}], "label": "{normal_caption}"}}\n]\n```',
|
||||
}
|
||||
]
|
||||
|
||||
example = self.list_data_dict[i]
|
||||
example["messages"] = make_conversation_image(example)
|
||||
return example
|
||||
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
texts = [
|
||||
processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=True)
|
||||
for example in examples
|
||||
]
|
||||
image_inputs = []
|
||||
for example in examples:
|
||||
imgs, vids = process_vision_info(example["messages"])
|
||||
image_inputs.append(imgs)
|
||||
batch = processor(
|
||||
text=texts,
|
||||
images=image_inputs,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
|
||||
labels[labels == image_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Set seed for reproducibility
|
||||
set_seed(training_args.seed)
|
||||
|
||||
###############
|
||||
# Setup logging
|
||||
###############
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process a small summary
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Script parameters {script_args}")
|
||||
logger.info(f"Data parameters {training_args}")
|
||||
|
||||
# Check for last checkpoint
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
|
||||
|
||||
################
|
||||
# Load datasets
|
||||
################
|
||||
|
||||
dataset = LazySupervisedDataset(script_args.dataset_name, script_args)
|
||||
|
||||
################
|
||||
# Load tokenizer
|
||||
################
|
||||
global processor
|
||||
if "vl" in model_args.model_name_or_path.lower():
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
logger.info("Using AutoProcessor for vision-language model.")
|
||||
else:
|
||||
processor = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
|
||||
)
|
||||
logger.info("Using AutoTokenizer for text-only model.")
|
||||
if hasattr(processor, "pad_token") and processor.pad_token is None:
|
||||
processor.pad_token = processor.eos_token
|
||||
elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None:
|
||||
processor.tokenizer.pad_token = processor.tokenizer.eos_token
|
||||
|
||||
###################
|
||||
# Model init kwargs
|
||||
###################
|
||||
logger.info("*** Initializing model kwargs ***")
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
# training_args.model_init_kwargs = model_kwargs
|
||||
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
|
||||
if "Qwen2-VL" in model_args.model_name_or_path:
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path, **model_kwargs
|
||||
)
|
||||
elif "Qwen2.5-VL" in model_args.model_name_or_path:
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path, **model_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model_args.model_name_or_path}")
|
||||
############################
|
||||
# Initialize the SFT Trainer
|
||||
############################
|
||||
training_args.dataset_kwargs = {
|
||||
"skip_prepare_dataset": True,
|
||||
}
|
||||
training_args.remove_unused_columns = False
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=None,
|
||||
processing_class=processor.tokenizer,
|
||||
data_collator=collate_fn,
|
||||
peft_config=get_peft_config(model_args),
|
||||
callbacks=get_callbacks(training_args, model_args),
|
||||
)
|
||||
|
||||
###############
|
||||
# Training loop
|
||||
###############
|
||||
logger.info("*** Train ***")
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
|
||||
# Save everything else on main process
|
||||
kwargs = {
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": list(script_args.dataset_name),
|
||||
"dataset_tags": list(script_args.dataset_name),
|
||||
"tags": ["open-r1"],
|
||||
}
|
||||
if trainer.accelerator.is_main_process:
|
||||
trainer.create_model_card(**kwargs)
|
||||
# Restore k,v cache for fast inference
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
#############
|
||||
# push to hub
|
||||
#############
|
||||
|
||||
if training_args.push_to_hub:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
print(script_args)
|
||||
main(script_args, training_args, model_args)
|
||||
@@ -0,0 +1,352 @@
|
||||
import tempfile
|
||||
import os
|
||||
from PIL import Image
|
||||
from playwright.sync_api import sync_playwright
|
||||
import os
|
||||
import open_clip
|
||||
import torch
|
||||
from skimage.metrics import structural_similarity as ssim
|
||||
import numpy as np
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
from lap import lapjv
|
||||
from multiprocessing import Pool
|
||||
# from selenium import webdriver
|
||||
# from selenium.webdriver.firefox.options import Options
|
||||
os.makedirs("./images", exist_ok=True)
|
||||
|
||||
def solve_assignment_lapjv(cost_matrix):
|
||||
_, col_idx, _ = lapjv(cost_matrix)
|
||||
return col_idx
|
||||
|
||||
|
||||
def process_imgs(image1, image2, max_size):
|
||||
# Get the original sizes
|
||||
width1, height1 = image1.size
|
||||
width2, height2 = image2.size
|
||||
|
||||
# Determine the new dimensions (max of both images' width and height)
|
||||
new_width = max(width1, width2)
|
||||
new_height = max(height1, height2)
|
||||
|
||||
# Pad images to the new dimensions with random values
|
||||
def pad_image(image, new_width, new_height):
|
||||
# Create a random padded background with the new dimensions
|
||||
random_padding = np.random.randint(0, 256, (new_height, new_width, 3), dtype=np.uint8)
|
||||
padded_image = Image.fromarray(random_padding)
|
||||
|
||||
# Paste the original image onto the padded background (placing in the top-left corner)
|
||||
padded_image.paste(image, (0, 0))
|
||||
|
||||
return padded_image
|
||||
|
||||
padded_image1 = pad_image(image1, new_width, new_height)
|
||||
padded_image2 = pad_image(image2, new_width, new_height)
|
||||
|
||||
# Calculate the aspect ratio for resizing to the max size
|
||||
aspect_ratio = min(max_size / new_width, max_size / new_height)
|
||||
new_size = (int(new_width * aspect_ratio), int(new_height * aspect_ratio))
|
||||
|
||||
# Resize the padded images to the specified max size
|
||||
resized_image1 = padded_image1.resize(new_size, Image.LANCZOS)
|
||||
resized_image2 = padded_image2.resize(new_size, Image.LANCZOS)
|
||||
|
||||
# resized_image1.show()
|
||||
# resized_image2.show()
|
||||
|
||||
# Convert the images to numpy arrays with dtype int16
|
||||
array1 = np.array(resized_image1).astype(np.int16)
|
||||
array2 = np.array(resized_image2).astype(np.int16)
|
||||
|
||||
return array1, array2
|
||||
|
||||
|
||||
|
||||
def calculate_emd_sim(img_array1, img_array2):
|
||||
"""img_array1 is the original image, img_array2 is the generated image"""
|
||||
if len(img_array1.shape) == 2:
|
||||
flat_array1 = img_array1.flatten()
|
||||
flat_array2 = img_array2.flatten()
|
||||
|
||||
cost_matrix = np.abs(flat_array1[:, None] - flat_array2[None, :])
|
||||
_, col_idx, _ = lapjv(cost_matrix)
|
||||
|
||||
total_min_cost = cost_matrix[np.arange(len(flat_array1)), col_idx].sum()
|
||||
max_cost = np.maximum(flat_array1, 255 - flat_array1).sum()
|
||||
normalized_min_cost = total_min_cost / max_cost
|
||||
|
||||
else:
|
||||
red1, green1, blue1 = img_array1[:, :, 0], img_array1[:, :, 1], img_array1[:, :, 2]
|
||||
red2, green2, blue2 = img_array2[:, :, 0], img_array2[:, :, 1], img_array2[:, :, 2]
|
||||
|
||||
flat_red1, flat_green1, flat_blue1 = red1.flatten(), green1.flatten(), blue1.flatten()
|
||||
flat_red2, flat_green2, flat_blue2 = red2.flatten(), green2.flatten(), blue2.flatten()
|
||||
|
||||
cost_matrix_red = np.abs(flat_red1[:, None] - flat_red2[None, :]).astype(np.float32)
|
||||
cost_matrix_green = np.abs(flat_green1[:, None] - flat_green2[None, :]).astype(np.float32)
|
||||
cost_matrix_blue = np.abs(flat_blue1[:, None] - flat_blue2[None, :]).astype(np.float32)
|
||||
|
||||
with Pool(processes=3) as pool:
|
||||
results = pool.map(solve_assignment_lapjv, [cost_matrix_red, cost_matrix_green, cost_matrix_blue])
|
||||
col_ind_red = results[0]
|
||||
col_ind_green = results[1]
|
||||
col_ind_blue = results[2]
|
||||
|
||||
min_cost_red_lapjv = cost_matrix_red[np.arange(len(flat_red1)), col_ind_red].sum()
|
||||
min_cost_green_lapjv = cost_matrix_green[np.arange(len(flat_green1)), col_ind_green].sum()
|
||||
min_cost_blue_lapjv = cost_matrix_blue[np.arange(len(flat_blue1)), col_ind_blue].sum()
|
||||
|
||||
total_min_cost_lapjv = min_cost_red_lapjv + min_cost_green_lapjv + min_cost_blue_lapjv
|
||||
max_cost = np.maximum(flat_red1, 255 - flat_red1).sum() + np.maximum(flat_green1, 255 - flat_green1).sum() + np.maximum(flat_blue1, 255 - flat_blue1).sum()
|
||||
normalized_min_cost = total_min_cost_lapjv / max_cost
|
||||
|
||||
# return {"cost": total_min_cost_lapjv, "normalized_sim": 1 - normalized_min_cost}
|
||||
return 1 - normalized_min_cost
|
||||
|
||||
def emd_similarity(image1_path, image2_path, max_size=64, mode="L"):
|
||||
"""not symmetric, the first image is the original image, the score is normalized according to the original image"""
|
||||
image1 = Image.open(image1_path).convert(mode) if type(image1_path) == str else image1_path.convert(mode)
|
||||
image2 = Image.open(image2_path).convert(mode) if type(image2_path) == str else image2_path.convert(mode)
|
||||
|
||||
array1, array2 = process_imgs(image1, image2, max_size)
|
||||
similarity = calculate_emd_sim(array1, array2)
|
||||
|
||||
return similarity
|
||||
|
||||
class CLIPScorer:
|
||||
def __init__(self, model_name='ViT-B-32-quickgelu', pretrained='openai'):
|
||||
"""
|
||||
Initializes the CLIPScorer with the specified model.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the CLIP model to use.
|
||||
pretrained (str): Specifies whether to load pre-trained weights.
|
||||
"""
|
||||
self.device = "cuda" if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else "cpu"
|
||||
self.device = "cpu" # Force CPU for compatibility
|
||||
self.model, _, self.preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
|
||||
self.model.to(self.device)
|
||||
|
||||
def score(self, img1: Image.Image, img2: Image.Image) -> float:
|
||||
"""
|
||||
Calculates the CLIP score (cosine similarity) between two images.
|
||||
|
||||
Args:
|
||||
img1 (Image.Image): The first image as a PIL Image.
|
||||
img2 (Image.Image): The second image as a PIL Image.
|
||||
|
||||
Returns:
|
||||
float: The cosine similarity score between the two images.
|
||||
"""
|
||||
# Preprocess the images
|
||||
image1 = self.preprocess(img1).unsqueeze(0).to(self.device)
|
||||
image2 = self.preprocess(img2).unsqueeze(0).to(self.device)
|
||||
|
||||
# Get the image features from CLIP using openclip
|
||||
with torch.no_grad():
|
||||
image1_features = self.model.encode_image(image1)
|
||||
image2_features = self.model.encode_image(image2)
|
||||
|
||||
# Normalize the features to unit length
|
||||
image1_features /= image1_features.norm(dim=-1, keepdim=True)
|
||||
image2_features /= image2_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Calculate cosine similarity between the two image features
|
||||
cosine_similarity = torch.nn.functional.cosine_similarity(image1_features, image2_features)
|
||||
return cosine_similarity.item()
|
||||
|
||||
clip_scorer = CLIPScorer()
|
||||
|
||||
def ssim_score(img1, img2):
|
||||
# resize images to match the size of the smaller image
|
||||
img1, img2 = process_imgs(img1, img2, 512)
|
||||
return ssim(img1, img2, channel_axis=-1, data_range=255)
|
||||
|
||||
|
||||
def mae_score(img1, img2):
|
||||
"""mean absolute error, it is a pixel-based metric"""
|
||||
img1, img2 = process_imgs(img1, img2, 1024)
|
||||
# max_mae = np.mean(np.maximum(img1, 255 - img1))
|
||||
mae = np.mean(np.abs(img1 - img2))
|
||||
# return {"mae": mae, "normalized_mae": 1 - mae / max_mae}
|
||||
return mae
|
||||
|
||||
def clip_mae(img1, img2):
|
||||
"""clip - mae/255"""
|
||||
mae = mae_score(img1, img2)
|
||||
clip = clip_scorer.score(img1, img2)
|
||||
return clip - mae/255 # scale mae by 255/200
|
||||
|
||||
|
||||
import re
|
||||
import base64
|
||||
|
||||
with open(os.getenv("CSS_PATH"), "r", encoding="utf-8") as f:
|
||||
TAILWIND_CSS = f.read()
|
||||
with open(os.getenv("PLACEHOLDER_PATH"), "rb") as f:
|
||||
PLACEHOLDER = base64.b64encode(f.read()).decode("utf-8")
|
||||
|
||||
def preprocess_html(html_str: str) -> str:
|
||||
# 1. Load and wrap Tailwind CSS in <style>
|
||||
style_tag = f"<style>{TAILWIND_CSS}</style>"
|
||||
html_str = html_str.replace('<link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">', style_tag)
|
||||
|
||||
# 3. Convert placeholder image to base64 and replace all occurrences
|
||||
base64_url = f"data:image/png;base64,{PLACEHOLDER}"
|
||||
html_str = html_str.replace("placeholder.jpg", base64_url)
|
||||
|
||||
return html_str
|
||||
|
||||
|
||||
|
||||
def generate_screenshot(html_content, path):
|
||||
|
||||
html_content = preprocess_html(html_content)
|
||||
|
||||
with sync_playwright() as p:
|
||||
|
||||
browser = p.chromium.launch(headless=True)
|
||||
page = browser.new_page()
|
||||
|
||||
# Set consistent rendering parameters
|
||||
page.set_viewport_size({"width": 1280, "height": 720})
|
||||
page.route("**/*", lambda route: route.continue_()) # Allow external resources
|
||||
|
||||
|
||||
# Render and screenshot
|
||||
page.set_content(html_content, timeout=100000)
|
||||
page.wait_for_load_state("networkidle", timeout=100000)
|
||||
page.screenshot(
|
||||
path=path,
|
||||
full_page=True,
|
||||
type="png",
|
||||
)
|
||||
browser.close()
|
||||
|
||||
|
||||
def rendered_score(gen_html, ref_html, score_func, verbose=True):
|
||||
"""Calculate visual similarity score between two HTML documents using screenshots."""
|
||||
|
||||
# if not verbose:
|
||||
# with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as ref_file, \
|
||||
# tempfile.NamedTemporaryFile(delete=False, suffix=".png") as gen_file:
|
||||
|
||||
# ref_path = ref_file.name
|
||||
# gen_path = gen_file.name
|
||||
|
||||
ref_name = str(abs(hash(ref_html)))[:6]
|
||||
gen_name = f"{ref_name}_{str(abs(hash(gen_html)))[:6]}"
|
||||
ref_path = f"./images/{ref_name}.png"
|
||||
gen_path = f"./images/{gen_name}.png"
|
||||
# with open(f"./images/{gen_name}.html", "w") as f:
|
||||
# f.write(gen_html)
|
||||
|
||||
|
||||
try:
|
||||
# Generate screenshots synchronously
|
||||
generate_screenshot(ref_html, ref_path)
|
||||
generate_screenshot(gen_html, gen_path)
|
||||
|
||||
# Calculate similarity score
|
||||
with Image.open(ref_path) as ref_img, Image.open(gen_path) as gen_img:
|
||||
if type(score_func) == list:
|
||||
score = []
|
||||
for func in score_func:
|
||||
score.append(func(ref_img, gen_img))
|
||||
if not verbose:
|
||||
os.remove(ref_path)
|
||||
os.remove(gen_path)
|
||||
return np.mean(score)
|
||||
|
||||
if not verbose:
|
||||
os.remove(ref_path)
|
||||
os.remove(gen_path)
|
||||
return score_func(ref_img, gen_img)
|
||||
|
||||
finally:
|
||||
# Cleanup temp files
|
||||
# os.unlink(ref_path)
|
||||
# os.unlink(gen_path)
|
||||
pass
|
||||
|
||||
|
||||
|
||||
import json
|
||||
import tqdm
|
||||
import numpy as np
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from functools import partial
|
||||
|
||||
def process_item(i, gt, base_json, sft_json, clip_scorer):
|
||||
gt_html = gt[i]
|
||||
base_html = base_json[i]
|
||||
sft_html = sft_json[i]
|
||||
|
||||
# calculate score
|
||||
base = rendered_score(gt_html, base_html, [ mae_score, clip_scorer.score])
|
||||
sft = rendered_score(gt_html, sft_html, [ mae_score, clip_scorer.score])
|
||||
|
||||
return {
|
||||
# "emd_base": emd_base,
|
||||
# "emd_sft": emd_sft,
|
||||
"base": base,
|
||||
"sft": sft,
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Load data
|
||||
# gt = json.load(open("test_selected.json"))
|
||||
# gt = [x["conversations"][1]["value"] for x in gt]
|
||||
|
||||
# # Load jsonl files
|
||||
# with open("mrweb_3b_original.jsonl", "r") as f:
|
||||
# base_json = [json.loads(line)["predict"] for line in f.readlines()]
|
||||
# with open("mrweb_3b_sft_2000.jsonl", "r") as f:
|
||||
# sft_json = [json.loads(line)["predict"] for line in f.readlines()]
|
||||
|
||||
# # Initialize CLIP scorer (once per process)
|
||||
# clip_scorer = CLIPScorer()
|
||||
|
||||
# # Create multiprocessing pool
|
||||
# from multiprocessing.pool import ThreadPool
|
||||
# num_processes = 6
|
||||
# with ThreadPool(processes=num_processes) as pool:
|
||||
# func = partial(process_item, gt=gt, base_json=base_json,
|
||||
# sft_json=sft_json, clip_scorer=clip_scorer)
|
||||
# results = list(tqdm.tqdm(
|
||||
# pool.imap(func, range(len(gt))),
|
||||
# total=len(gt)))
|
||||
|
||||
# # Summarize results
|
||||
# base_scores = [x["base"] for x in results]
|
||||
# sft_scores = [x["sft"] for x in results]
|
||||
# print("base scores: ", np.mean(base_scores))
|
||||
|
||||
# write a minimal code to test playwright screenshot
|
||||
html = """<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Interior Design Firm</title>
|
||||
<link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
|
||||
</head>
|
||||
<body class="bg-gray-100 font-sans leading-normal max-w-3xl mx-auto">
|
||||
<section class="flex flex-col items-center justify-between p-6">
|
||||
<section class="flex-1 text-left">
|
||||
<h1 class="text-3xl font-bold mb-4 text-gray-800">Interior Design Firm</h1>
|
||||
<p class="mb-4 text-gray-600">Welcome to our interior design firm, where we specialize in creating beautiful and functional spaces that reflect your individual style and personality. Our team of experienced designers work closely with you to understand your needs and vision, and then transform that into a reality. We believe that every space is a reflection of the people who inhabit it, and we strive to make that connection as strong as possible.</p>
|
||||
<h2 class="text-2xl font-bold mb-2">Testimonials</h2>
|
||||
<p class="mb-4 text-gray-600">"I was blown away by the professionalism and creativity of the interior design team at our new home. They truly understood my vision and brought it to life. I couldn't be happier with the results." - John D.</p>
|
||||
<p class="mb-4 text-gray-600">"I was hesitant to hire a professional designer, but I'm so glad I did. The results are beyond my expectations. I'm so happy with the space I've created for myself and my family." - Jane S.</p>
|
||||
</section>
|
||||
<section class="container mx-auto ml-72 mt-4">
|
||||
<img src="placeholder.jpg" class="w-1/3 mx-auto h-48 rounded-full shadow-lg" alt="Testimonial image">
|
||||
</section>
|
||||
</section>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
path1 = "test1.png"
|
||||
generate_screenshot(html, path1)
|
||||
# print(rendered_score(html_ref, html_gen, [mae_score, CLIPScorer().score]))
|
||||
# generate_screenshot(html_ref, "test.png")
|
||||
@@ -0,0 +1,4 @@
|
||||
from .grpo_trainer import VLMGRPOTrainer
|
||||
from .grpo_config import GRPOConfig
|
||||
|
||||
__all__ = ["VLMGRPOTrainer"]
|
||||
@@ -0,0 +1,350 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRPOConfig(TrainingArguments):
|
||||
r"""
|
||||
Configuration class for the [`GRPOTrainer`].
|
||||
|
||||
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
|
||||
[`~transformers.TrainingArguments`] documentation.
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
> Parameters that control the model and reference model
|
||||
|
||||
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
||||
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
||||
argument of the [`GRPOTrainer`] is provided as a string.
|
||||
|
||||
> Parameters that control the data preprocessing
|
||||
|
||||
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
|
||||
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
|
||||
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
||||
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
|
||||
num_generations (`int` or `None`, *optional*, defaults to `8`):
|
||||
Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
|
||||
must be divisible by this value.
|
||||
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
|
||||
Maximum length of the generated completion.
|
||||
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
||||
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
||||
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
||||
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
|
||||
with vLLM generation.
|
||||
|
||||
> Parameters that control generation
|
||||
|
||||
temperature (`float`, defaults to `0.9`):
|
||||
Temperature for sampling. The higher the temperature, the more random the completions.
|
||||
top_p (`float`, *optional*, defaults to `1.0`):
|
||||
Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
|
||||
`1.0` to consider all tokens.
|
||||
top_k (`int` or `None`, *optional*, defaults to `50`):
|
||||
Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
|
||||
disabled.
|
||||
min_p (`float` or `None`, *optional*, defaults to `None`):
|
||||
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
|
||||
value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
|
||||
repetition_penalty (`float`, *optional*, defaults to `1.0`):
|
||||
Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
|
||||
Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
|
||||
tokens.
|
||||
cache_implementation (`str` or `None`, *optional*, defaults to `None`):
|
||||
Implementation of the cache method for faster generation when use_vllm is set to False.
|
||||
|
||||
> Parameters that control generation acceleration powered by vLLM
|
||||
|
||||
use_vllm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
|
||||
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
|
||||
vllm_device (`str`, *optional*, defaults to `"auto"`):
|
||||
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
|
||||
automatically select the next available GPU after the last one used for training. This assumes that
|
||||
training has not already occupied all available GPUs. If only one device is available, the device will be
|
||||
shared between both training and vLLM.
|
||||
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
||||
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
|
||||
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
|
||||
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
|
||||
during initialization.
|
||||
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
|
||||
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
|
||||
based on the model configuration. Find the supported values in the vLLM documentation.
|
||||
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
|
||||
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
|
||||
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
|
||||
context size, which might be much larger than the KV cache, leading to inefficiencies.
|
||||
vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
|
||||
Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
|
||||
support this feature.
|
||||
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
|
||||
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
|
||||
|
||||
> Parameters that control the training
|
||||
|
||||
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
||||
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
||||
[`~transformers.TrainingArguments`].
|
||||
beta (`float`, *optional*, defaults to `0.04`):
|
||||
KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
|
||||
speed, but may be numerically unstable for long training runs.
|
||||
num_iterations (`int`, *optional*, defaults to `1`):
|
||||
Number of iterations per batch (denoted as μ in the algorithm).
|
||||
epsilon (`float`, *optional*, defaults to `0.2`):
|
||||
Epsilon value for clipping.
|
||||
epsilon_high (`float` or `None`, *optional*, defaults to `None`):
|
||||
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
|
||||
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
|
||||
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
|
||||
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
|
||||
weighted equally with weight `1.0`.
|
||||
sync_ref_model (`bool`, *optional*, defaults to `False`):
|
||||
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
|
||||
the `ref_model_mixup_alpha` parameter. This synchronization originites from the
|
||||
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
|
||||
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
|
||||
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
|
||||
between the current policy and the previous reference policy during updates. The reference policy is
|
||||
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
|
||||
must set `sync_ref_model=True`.
|
||||
ref_model_sync_steps (`int`, *optional*, defaults to `512`):
|
||||
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
|
||||
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
|
||||
set `sync_ref_model=True`.
|
||||
|
||||
> Parameters that control the logging
|
||||
|
||||
log_completions (`bool`, *optional*, defaults to `False`):
|
||||
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
|
||||
installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
|
||||
"""
|
||||
|
||||
# Parameters that control the model and reference model
|
||||
model_init_kwargs: Optional[dict] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
|
||||
"argument of the `GRPOTrainer` is provided as a string."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control the data preprocessing
|
||||
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
|
||||
# additional columns to compute the reward
|
||||
remove_unused_columns: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
|
||||
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
|
||||
},
|
||||
)
|
||||
max_prompt_length: Optional[int] = field(
|
||||
default=4096,
|
||||
metadata={
|
||||
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
|
||||
},
|
||||
)
|
||||
num_generations: Optional[int] = field(
|
||||
default=8,
|
||||
metadata={
|
||||
"help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
|
||||
"must be divisible by this value."
|
||||
},
|
||||
)
|
||||
max_completion_length: Optional[int] = field(
|
||||
default=256,
|
||||
metadata={"help": "Maximum length of the generated completion."},
|
||||
)
|
||||
ds3_gather_for_generation: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
|
||||
"generation, improving generation speed. However, disabling this option allows training models that "
|
||||
"exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
|
||||
"is not compatible with vLLM generation."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control generation
|
||||
temperature: float = field(
|
||||
default=0.9,
|
||||
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
|
||||
)
|
||||
top_p: float = field(
|
||||
default=1.0,
|
||||
metadata={
|
||||
"help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
|
||||
"Set to 1.0 to consider all tokens."
|
||||
},
|
||||
)
|
||||
top_k: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={
|
||||
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
|
||||
"top-k-filtering is disabled."
|
||||
},
|
||||
)
|
||||
min_p: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
|
||||
"must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
|
||||
},
|
||||
)
|
||||
repetition_penalty: float = field(
|
||||
default=1.0,
|
||||
metadata={
|
||||
"help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
|
||||
"text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
|
||||
"to repeat tokens."
|
||||
},
|
||||
)
|
||||
cache_implementation: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
|
||||
)
|
||||
|
||||
# Parameters that control generation acceleration powered by vLLM
|
||||
use_vllm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
|
||||
"unused for training, as vLLM will require one for generation. vLLM must be installed "
|
||||
"(`pip install vllm`)."
|
||||
},
|
||||
)
|
||||
vllm_device: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
|
||||
"will automatically select the next available GPU after the last one used for training. This assumes "
|
||||
"that training has not already occupied all available GPUs."
|
||||
},
|
||||
)
|
||||
vllm_gpu_memory_utilization: float = field(
|
||||
default=0.9,
|
||||
metadata={
|
||||
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
||||
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
|
||||
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
|
||||
"out-of-memory (OOM) errors during initialization."
|
||||
},
|
||||
)
|
||||
vllm_dtype: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
||||
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
||||
},
|
||||
)
|
||||
vllm_max_model_len: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
|
||||
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
|
||||
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
||||
},
|
||||
)
|
||||
vllm_enable_prefix_caching: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
|
||||
"the hardware support this feature."
|
||||
},
|
||||
)
|
||||
vllm_guided_decoding_regex: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
|
||||
)
|
||||
|
||||
# Parameters that control the training
|
||||
learning_rate: float = field(
|
||||
default=1e-6,
|
||||
metadata={
|
||||
"help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
|
||||
"`transformers.TrainingArguments`."
|
||||
},
|
||||
)
|
||||
beta: float = field(
|
||||
default=0.04,
|
||||
metadata={
|
||||
"help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
|
||||
"training speed, but may be numerically unstable for long training runs."
|
||||
},
|
||||
)
|
||||
num_iterations: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
|
||||
)
|
||||
epsilon: float = field(
|
||||
default=0.2,
|
||||
metadata={"help": "Epsilon value for clipping."},
|
||||
)
|
||||
epsilon_high: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
|
||||
"lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
|
||||
},
|
||||
)
|
||||
reward_weights: Optional[list[float]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
|
||||
"rewards are weighted equally with weight `1.0`."
|
||||
},
|
||||
)
|
||||
sync_ref_model: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
|
||||
"steps, using the `ref_model_mixup_alpha` parameter."
|
||||
},
|
||||
)
|
||||
ref_model_mixup_alpha: float = field(
|
||||
default=0.6,
|
||||
metadata={
|
||||
"help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
|
||||
"previous reference policy during updates. The reference policy is updated according to the equation: "
|
||||
"`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
|
||||
},
|
||||
)
|
||||
ref_model_sync_steps: int = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
|
||||
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control the logging
|
||||
log_completions: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
|
||||
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,864 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Optional, Union, Sized
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import transformers
|
||||
from datasets import Dataset, IterableDataset
|
||||
from packaging import version
|
||||
from transformers import (
|
||||
AriaForConditionalGeneration,
|
||||
AriaProcessor,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
GenerationConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
is_wandb_available,
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import is_peft_available
|
||||
|
||||
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
||||
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
|
||||
from trl.trainer.grpo_config import GRPOConfig
|
||||
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
|
||||
# from trl import GRPOTrainer
|
||||
|
||||
from accelerate.utils import is_peft_model, set_seed
|
||||
import PIL.Image
|
||||
|
||||
import copy
|
||||
from torch.utils.data import Sampler
|
||||
import warnings
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftConfig, get_peft_model
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
from open_r1.vlm_modules.vlm_module import VLMBaseModule
|
||||
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
||||
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
||||
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
||||
|
||||
|
||||
class RepeatRandomSampler(Sampler):
|
||||
"""
|
||||
Sampler that repeats the indices of a dataset in a structured manner.
|
||||
|
||||
Args:
|
||||
data_source (`Sized`):
|
||||
Dataset to sample from.
|
||||
mini_repeat_count (`int`):
|
||||
Number of times to repeat each index per batch.
|
||||
batch_size (`int`, *optional*, defaults to `1`):
|
||||
Number of unique indices per batch.
|
||||
repeat_count (`int`, *optional*, defaults to `1`):
|
||||
Number of times to repeat the full sampling process.
|
||||
seed (`int` or `None`, *optional*, defaults to `None`):
|
||||
Random seed for reproducibility.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_source: Sized,
|
||||
mini_repeat_count: int,
|
||||
batch_size: int = 1,
|
||||
repeat_count: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
self.data_source = data_source
|
||||
self.mini_repeat_count = mini_repeat_count
|
||||
self.batch_size = batch_size
|
||||
self.repeat_count = repeat_count
|
||||
self.num_samples = len(data_source)
|
||||
self.seed = seed
|
||||
self.generator = torch.Generator()
|
||||
if seed is not None:
|
||||
self.generator.manual_seed(seed)
|
||||
|
||||
def __iter__(self):
|
||||
indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
||||
indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
|
||||
indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
|
||||
|
||||
for chunk in indexes:
|
||||
for _ in range(self.repeat_count):
|
||||
for index in chunk:
|
||||
for _ in range(self.mini_repeat_count):
|
||||
yield index
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples * self.mini_repeat_count * self.repeat_count
|
||||
|
||||
|
||||
class VLMGRPOTrainer(Trainer):
|
||||
"""
|
||||
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
|
||||
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs="weqweasdas/RM-Gemma-2B",
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Args:
|
||||
model (`Union[str, PreTrainedModel]`):
|
||||
Model to be trained. Can be either:
|
||||
|
||||
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
||||
a path to a *directory* containing model weights saved using
|
||||
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
||||
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
||||
in `args.model_init_kwargs`.
|
||||
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
||||
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
|
||||
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
||||
functions with the prompts and completions and sum the rewards. Can be either:
|
||||
|
||||
- A single reward function, such as:
|
||||
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
||||
path to a *directory* containing model weights saved using
|
||||
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
||||
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
|
||||
keyword arguments in `args.model_init_kwargs`.
|
||||
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
|
||||
- A custom reward function: The function is provided with the prompts and the generated completions,
|
||||
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
|
||||
[Using a custom reward function](#using-a-custom-reward-function).
|
||||
- A list of reward functions, where each item can independently be any of the above types. Mixing different
|
||||
types within the list (e.g., a string model ID and a custom reward function) is allowed.
|
||||
args ([`GRPOConfig`], *optional*, defaults to `None`):
|
||||
Configuration for this trainer. If `None`, a default configuration is used.
|
||||
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
||||
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
|
||||
ignored. The format of the samples can be either:
|
||||
|
||||
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
||||
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
||||
and content).
|
||||
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
||||
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
||||
Processing class used to process the data. The padding side must be set to "left". If `None`, the
|
||||
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
|
||||
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
|
||||
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
||||
|
||||
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
||||
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
|
||||
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
|
||||
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
|
||||
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
|
||||
the corresponding entries in `reward_processing_classes` are ignored.
|
||||
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
||||
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
||||
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
||||
|
||||
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
||||
method.
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
||||
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
||||
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
||||
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
||||
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, PreTrainedModel],
|
||||
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
||||
args: GRPOConfig = None,
|
||||
vlm_module: VLMBaseModule = None,
|
||||
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
||||
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
||||
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
||||
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
freeze_vision_modules: Optional[bool] = False,
|
||||
attn_implementation: str = "flash_attention_2",
|
||||
torch_dtype: str = "bfloat16",
|
||||
**kwargs,
|
||||
):
|
||||
# Args
|
||||
if args is None:
|
||||
model_name = model if isinstance(model, str) else model.config._name_or_path
|
||||
model_name = model_name.split("/")[-1]
|
||||
args = GRPOConfig(f"{model_name}-GRPO")
|
||||
|
||||
self.vlm_module = vlm_module
|
||||
|
||||
# Models
|
||||
# Trained model
|
||||
model_init_kwargs = args.model_init_kwargs or {}
|
||||
# FIXME
|
||||
# Remember to modify it in the invernvl
|
||||
model_init_kwargs["attn_implementation"] = attn_implementation
|
||||
if model_init_kwargs.get("torch_dtype") is None:
|
||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
|
||||
assert isinstance(model, str), "model must be a string in the current implementation"
|
||||
model_id = model
|
||||
torch_dtype = model_init_kwargs.get("torch_dtype")
|
||||
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
||||
pass # torch_dtype is already a torch.dtype or "auto" or None
|
||||
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
||||
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
||||
)
|
||||
# Disable caching if gradient checkpointing is enabled (not supported)
|
||||
model_init_kwargs["use_cache"] = (
|
||||
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
||||
)
|
||||
model_cls = self.vlm_module.get_model_class(model_id, model_init_kwargs)
|
||||
model = model_cls.from_pretrained(model_id, **model_init_kwargs)
|
||||
|
||||
# LoRA
|
||||
self.vision_modules_keywords = self.vlm_module.get_vision_modules_keywords()
|
||||
if peft_config is not None:
|
||||
print("Applying LoRA...")
|
||||
def find_all_linear_names(model, multimodal_keywords):
|
||||
cls = torch.nn.Linear
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
# LoRA is not applied to the vision modules
|
||||
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
|
||||
continue
|
||||
if isinstance(module, cls):
|
||||
lora_module_names.add(name)
|
||||
for m in lora_module_names: # needed for 16-bit
|
||||
if "embed_tokens" in m:
|
||||
lora_module_names.remove(m)
|
||||
return list(lora_module_names)
|
||||
target_modules = find_all_linear_names(model, self.vision_modules_keywords)
|
||||
peft_config.target_modules = target_modules
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
# Freeze vision modules
|
||||
if freeze_vision_modules:
|
||||
print("Freezing vision modules...")
|
||||
for n, p in model.named_parameters():
|
||||
if any(keyword in n for keyword in self.vision_modules_keywords):
|
||||
p.requires_grad = False
|
||||
# Compute the number of trainable parameters and print the parameter that is trainable
|
||||
trainable_params = [p for p in model.parameters() if p.requires_grad]
|
||||
total_params = sum(p.numel() for p in trainable_params)
|
||||
# for n, p in model.named_parameters():
|
||||
# if p.requires_grad:
|
||||
# print(n, p.shape)
|
||||
print(f"Total trainable parameters: {total_params}")
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if args.gradient_checkpointing:
|
||||
model = self._enable_gradient_checkpointing(model, args)
|
||||
|
||||
# Reference model
|
||||
self.beta = args.beta
|
||||
if self.beta == 0.0:
|
||||
# If beta is 0.0, the reference model is not needed
|
||||
self.ref_model = None
|
||||
elif is_deepspeed_zero3_enabled():
|
||||
self.ref_model = model_cls.from_pretrained(model_id, **model_init_kwargs)
|
||||
elif is_peft_model(model):
|
||||
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
||||
# to revert to the initial model.
|
||||
self.ref_model = None
|
||||
else:
|
||||
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
||||
self.ref_model = create_reference_model(model)
|
||||
|
||||
# Processing class
|
||||
if processing_class is None:
|
||||
processing_cls = self.vlm_module.get_processing_class()
|
||||
processing_class = processing_cls.from_pretrained(model_id, trust_remote_code=model_init_kwargs.get("trust_remote_code", None))
|
||||
for component, processing_keyword in self.vlm_module.get_custom_processing_keywords():
|
||||
if processing_keyword in kwargs:
|
||||
# If we cannot find component in processing_class, return the processing_class itself
|
||||
processing_component = getattr(processing_class, component, processing_class)
|
||||
setattr(processing_component, processing_keyword, kwargs[processing_keyword])
|
||||
if getattr(processing_class, "tokenizer", None) is not None:
|
||||
pad_token_id = processing_class.tokenizer.pad_token_id
|
||||
processing_class.pad_token_id = pad_token_id
|
||||
processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
|
||||
else:
|
||||
assert isinstance(processing_class, PreTrainedTokenizerBase), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
|
||||
pad_token_id = processing_class.pad_token_id
|
||||
|
||||
self.vlm_module.post_model_init(model, processing_class)
|
||||
self.vlm_module.post_model_init(self.ref_model, processing_class)
|
||||
|
||||
# Reward functions
|
||||
if not isinstance(reward_funcs, list):
|
||||
reward_funcs = [reward_funcs]
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
if isinstance(reward_func, str):
|
||||
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
||||
reward_func, num_labels=1, **model_init_kwargs
|
||||
)
|
||||
self.reward_funcs = reward_funcs
|
||||
|
||||
# Reward processing class
|
||||
if reward_processing_classes is None:
|
||||
reward_processing_classes = [None] * len(reward_funcs)
|
||||
elif not isinstance(reward_processing_classes, list):
|
||||
reward_processing_classes = [reward_processing_classes]
|
||||
else:
|
||||
if len(reward_processing_classes) != len(reward_funcs):
|
||||
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
||||
|
||||
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if reward_processing_class is None:
|
||||
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
||||
if reward_processing_class.pad_token_id is None:
|
||||
reward_processing_class.pad_token = reward_processing_class.eos_token
|
||||
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
||||
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
||||
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
||||
reward_processing_classes[i] = reward_processing_class
|
||||
self.reward_processing_classes = reward_processing_classes
|
||||
|
||||
# Data collator
|
||||
def data_collator(features): # No data collation is needed in GRPO
|
||||
return features
|
||||
|
||||
# Training arguments
|
||||
self.max_prompt_length = args.max_prompt_length
|
||||
self.max_prompt_length = None
|
||||
if args.max_prompt_length is not None:
|
||||
warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
|
||||
|
||||
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
||||
self.num_generations = args.num_generations # = G in the GRPO paper
|
||||
self.generation_config = GenerationConfig(
|
||||
max_new_tokens=self.max_completion_length,
|
||||
do_sample=True,
|
||||
temperature=1,
|
||||
pad_token_id=pad_token_id,
|
||||
)
|
||||
if hasattr(self.vlm_module, "get_eos_token_id"): # For InternVL
|
||||
self.generation_config.eos_token_id = self.vlm_module.get_eos_token_id(processing_class)
|
||||
self.beta = args.beta
|
||||
self.epsilon_low = args.epsilon
|
||||
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
|
||||
|
||||
# Multi-step
|
||||
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
|
||||
# Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
|
||||
self._step = 0
|
||||
# Buffer the batch to reuse generated outputs across multiple updates
|
||||
self._buffered_inputs = [None] * args.gradient_accumulation_steps
|
||||
|
||||
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
||||
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
||||
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
||||
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
||||
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
||||
# This acts as a flag to indicate that the warning has already been issued.
|
||||
model.warnings_issued["estimate_tokens"] = True
|
||||
|
||||
# Initialize the metrics
|
||||
self._metrics = defaultdict(list)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
)
|
||||
|
||||
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
||||
num_processes = self.accelerator.num_processes
|
||||
global_batch_size = args.per_device_train_batch_size * num_processes
|
||||
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
||||
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
||||
f"batch size, the valid values for the number of generations are: {possible_values}."
|
||||
)
|
||||
if self.args.eval_strategy != "no":
|
||||
global_batch_size = args.per_device_eval_batch_size * num_processes
|
||||
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
||||
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
||||
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
||||
)
|
||||
|
||||
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
||||
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
||||
# it's safer to set it in all cases.
|
||||
set_seed(args.seed, device_specific=True)
|
||||
|
||||
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
||||
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
||||
# self.model_accepts_loss_kwargs to False to enable scaling.
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
if self.ref_model is not None:
|
||||
# if self.is_deepspeed_enabled:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
||||
|
||||
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# Ensure use_cache is disabled
|
||||
model.config.use_cache = False
|
||||
|
||||
# Enable gradient checkpointing on the base model for PEFT
|
||||
if is_peft_model(model):
|
||||
model.base_model.gradient_checkpointing_enable()
|
||||
# Enable gradient checkpointing for non-PEFT models
|
||||
else:
|
||||
model.gradient_checkpointing_enable()
|
||||
try:
|
||||
# For InternVL; these operations are copied from the original training script of InternVL
|
||||
model.language_model.config.use_cache = False
|
||||
model.vision_model.gradient_checkpointing = True
|
||||
model.vision_model.encoder.gradient_checkpointing = True
|
||||
model.language_model._set_gradient_checkpointing()
|
||||
# This line is necessary, otherwise the `model.gradient_checkpointing_enable()` will be executed during the training process, leading to an error since InternVL does not support this operation.
|
||||
args.gradient_checkpointing = False
|
||||
except:
|
||||
pass
|
||||
|
||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||
use_reentrant = (
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
)
|
||||
|
||||
if use_reentrant:
|
||||
model.enable_input_require_grads()
|
||||
|
||||
return model
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
||||
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
||||
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
||||
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
||||
if self._signature_columns is None:
|
||||
self._signature_columns = ["prompt"]
|
||||
|
||||
|
||||
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||
def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
|
||||
logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits # (B, L, V)
|
||||
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||
input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
|
||||
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
|
||||
per_token_logps = []
|
||||
for logits_row, input_ids_row in zip(logits, input_ids):
|
||||
log_probs = logits_row.log_softmax(dim=-1)
|
||||
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
|
||||
per_token_logps.append(token_log_prob)
|
||||
return torch.stack(per_token_logps)
|
||||
|
||||
|
||||
def _prepare_inputs(self, inputs):
|
||||
# Simple pass-through, just like original
|
||||
return inputs
|
||||
|
||||
def _get_key_from_inputs(self, x, key):
|
||||
ele = x.get(key, None)
|
||||
assert ele is not None, f"The key {key} is not found in the input"
|
||||
if isinstance(ele, list):
|
||||
return [e for e in ele]
|
||||
else:
|
||||
return [ele]
|
||||
|
||||
def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
|
||||
device = self.accelerator.device
|
||||
prompts = [x["prompt"] for x in inputs]
|
||||
prompts_text = self.vlm_module.prepare_prompt(self.processing_class, inputs)
|
||||
# Handle both pre-loaded images and image paths
|
||||
images = []
|
||||
for x in inputs:
|
||||
if "image" in x:
|
||||
imgs = self._get_key_from_inputs(x, "image")
|
||||
elif "image_path" in x and x["image_path"] is not None:
|
||||
imgs = [PIL.Image.open(p) for p in self._get_key_from_inputs(x, "image_path")]
|
||||
else:
|
||||
imgs = []
|
||||
|
||||
for img in imgs:
|
||||
try:
|
||||
# Ensure minimum dimensions of 28 pixels
|
||||
w, h = img.size
|
||||
if w < 28 or h < 28:
|
||||
# Calculate new dimensions maintaining aspect ratio
|
||||
if w < h:
|
||||
new_w = 28
|
||||
new_h = int(h * (28/w))
|
||||
else:
|
||||
new_h = 28
|
||||
new_w = int(w * (28/h))
|
||||
img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
|
||||
except:
|
||||
pass
|
||||
images.append(img)
|
||||
|
||||
|
||||
prompt_inputs = self.vlm_module.prepare_model_inputs(
|
||||
self.processing_class,
|
||||
prompts_text,
|
||||
images,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
add_special_tokens=False,
|
||||
)
|
||||
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
|
||||
|
||||
# max_prompt_length is not supported yet
|
||||
# if self.max_prompt_length is not None:
|
||||
# prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
||||
# prompt_inputs["input_ids"] = prompt_ids
|
||||
# prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
||||
# prompt_inputs["attention_mask"] = prompt_mask
|
||||
|
||||
# Generate completions
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
generate_returned_result = unwrapped_model.generate(
|
||||
**{k: v for k, v in prompt_inputs.items() if k not in self.vlm_module.get_non_generate_params()},
|
||||
generation_config=self.generation_config
|
||||
)
|
||||
prompt_length = prompt_ids.size(1)
|
||||
if not self.vlm_module.is_embeds_input():
|
||||
prompt_completion_ids = generate_returned_result
|
||||
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
||||
completion_ids = prompt_completion_ids[:, prompt_length:]
|
||||
else:
|
||||
# In this case, the input of the LLM backbone is the embedding of the combination of the image and text prompt
|
||||
# So the returned result of the `generate` method only contains the completion ids
|
||||
completion_ids = generate_returned_result
|
||||
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
|
||||
# Mask everything after the first EOS token
|
||||
is_eos = completion_ids == self.processing_class.eos_token_id
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
||||
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
||||
|
||||
# Concatenate prompt_mask with completion_mask for logit computation
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
|
||||
|
||||
# Get the multimodal inputs
|
||||
multimodal_keywords = self.vlm_module.get_custom_multimodal_keywords()
|
||||
multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}
|
||||
with torch.no_grad():
|
||||
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
|
||||
# computation here, and use per_token_logps.detach() instead.
|
||||
if self.num_iterations > 1:
|
||||
old_per_token_logps = self._get_per_token_logps(
|
||||
model, prompt_completion_ids, attention_mask, **multimodal_inputs
|
||||
)
|
||||
old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
|
||||
else:
|
||||
old_per_token_logps = None
|
||||
|
||||
if self.beta == 0.0:
|
||||
ref_per_token_logps = None
|
||||
elif self.ref_model is not None:
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
|
||||
)
|
||||
else:
|
||||
with self.accelerator.unwrap_model(model).disable_adapter():
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
model, prompt_completion_ids, attention_mask, **multimodal_inputs
|
||||
)
|
||||
if ref_per_token_logps is not None:
|
||||
ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]
|
||||
|
||||
# Decode the generated completions
|
||||
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
if is_conversational(inputs[0]):
|
||||
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
|
||||
|
||||
# Compute the rewards
|
||||
# No need to duplicate prompts as we're not generating multiple completions per prompt
|
||||
|
||||
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
|
||||
for i, (reward_func, reward_processing_class) in enumerate(
|
||||
zip(self.reward_funcs, self.reward_processing_classes)
|
||||
):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if is_conversational(inputs[0]):
|
||||
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
||||
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
||||
else:
|
||||
texts = [p + c for p, c in zip(prompts, completions)]
|
||||
reward_inputs = reward_processing_class(
|
||||
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
|
||||
)
|
||||
reward_inputs = super()._prepare_inputs(reward_inputs)
|
||||
with torch.inference_mode():
|
||||
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
|
||||
else:
|
||||
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
||||
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
|
||||
for key in reward_kwargs:
|
||||
for example in inputs:
|
||||
# No need to duplicate prompts as we're not generating multiple completions per prompt
|
||||
# reward_kwargs[key].extend([example[key]] * self.num_generations)
|
||||
reward_kwargs[key].extend([example[key]])
|
||||
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
|
||||
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
|
||||
|
||||
# Gather rewards across processes
|
||||
rewards_per_func = self.accelerator.gather(rewards_per_func)
|
||||
|
||||
# Sum the rewards from all reward functions
|
||||
rewards = rewards_per_func.sum(dim=1)
|
||||
|
||||
# Compute grouped-wise rewards
|
||||
# Each group consists of num_generations completions for the same prompt
|
||||
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
||||
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
||||
|
||||
# Normalize the rewards to compute the advantages
|
||||
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# Get only the local slice of advantages
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
)
|
||||
advantages = advantages[process_slice]
|
||||
|
||||
# Log the metrics
|
||||
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
||||
self._metrics["completion_length"].append(completion_length)
|
||||
|
||||
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
||||
else:
|
||||
reward_func_name = reward_func.__name__
|
||||
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
||||
|
||||
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
||||
|
||||
self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
|
||||
|
||||
return {
|
||||
"prompt_ids": prompt_ids,
|
||||
"prompt_mask": prompt_mask,
|
||||
"completion_ids": completion_ids,
|
||||
"completion_mask": completion_mask,
|
||||
"old_per_token_logps": old_per_token_logps,
|
||||
"ref_per_token_logps": ref_per_token_logps,
|
||||
"advantages": advantages,
|
||||
"multimodal_inputs": multimodal_inputs
|
||||
}
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
if return_outputs:
|
||||
raise ValueError("The GRPOTrainer does not support returning outputs")
|
||||
|
||||
# Check if we need to generate new completions or use buffered ones
|
||||
if self.state.global_step % self.num_iterations == 0:
|
||||
inputs = self._generate_and_score_completions(inputs, model)
|
||||
self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
|
||||
else:
|
||||
inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
|
||||
self._step += 1
|
||||
|
||||
# Get the prepared inputs
|
||||
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
||||
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
||||
multimodal_inputs = inputs["multimodal_inputs"]
|
||||
|
||||
# Concatenate for full sequence
|
||||
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
||||
|
||||
# Get the current policy's log probabilities
|
||||
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
|
||||
# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
|
||||
per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]
|
||||
|
||||
# Get the advantages from inputs
|
||||
advantages = inputs["advantages"]
|
||||
|
||||
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
|
||||
# and use per_token_logps.detach() instead
|
||||
old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
|
||||
|
||||
# Compute the policy ratio and clipped version
|
||||
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
|
||||
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
|
||||
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
|
||||
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
|
||||
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
|
||||
|
||||
# Add KL penalty if beta > 0
|
||||
if self.beta > 0:
|
||||
ref_per_token_logps = inputs["ref_per_token_logps"]
|
||||
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
||||
per_token_loss = per_token_loss + self.beta * per_token_kl
|
||||
|
||||
# Log KL divergence
|
||||
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
||||
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
||||
|
||||
# Compute final loss
|
||||
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
||||
|
||||
# Log clip ratio
|
||||
is_clipped = (per_token_loss1 < per_token_loss2).float()
|
||||
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
|
||||
self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
|
||||
|
||||
return loss
|
||||
|
||||
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
||||
logs = {**logs, **metrics}
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
super().log(logs, start_time)
|
||||
else: # transformers<=4.46
|
||||
super().log(logs)
|
||||
self._metrics.clear()
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
tags = tags or []
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.append("unsloth")
|
||||
|
||||
citation = textwrap.dedent(
|
||||
"""\
|
||||
@article{zhihong2024deepseekmath,
|
||||
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
|
||||
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
|
||||
year = 2024,
|
||||
eprint = {arXiv:2402.03300},
|
||||
"""
|
||||
)
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="GRPO",
|
||||
trainer_citation=citation,
|
||||
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
|
||||
paper_id="2402.03300",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
|
||||
def _get_train_sampler(self) -> Sampler:
|
||||
"""Returns a sampler that ensures proper data sampling for GRPO training."""
|
||||
effective_batch_size = (
|
||||
self.args.per_device_train_batch_size
|
||||
* self.accelerator.num_processes
|
||||
* self.args.gradient_accumulation_steps
|
||||
)
|
||||
|
||||
return RepeatRandomSampler(
|
||||
data_source=self.train_dataset,
|
||||
mini_repeat_count=self.num_generations,
|
||||
batch_size=effective_batch_size // self.num_generations,
|
||||
repeat_count=self.num_iterations,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset) -> Sampler:
|
||||
"""Returns a sampler for evaluation."""
|
||||
return RepeatRandomSampler(
|
||||
data_source=eval_dataset,
|
||||
mini_repeat_count=self.num_generations,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
@@ -0,0 +1,825 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
from accelerate.utils import broadcast_object_list, gather, gather_object
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import transformers
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
from datasets import Dataset, IterableDataset
|
||||
from packaging import version
|
||||
from transformers import (
|
||||
AriaForConditionalGeneration,
|
||||
AriaProcessor,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
GenerationConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
is_wandb_available,
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import is_peft_available
|
||||
|
||||
from trl.data_utils import (
|
||||
apply_chat_template,
|
||||
is_conversational,
|
||||
maybe_apply_chat_template,
|
||||
)
|
||||
from trl.import_utils import is_vllm_available
|
||||
|
||||
from trl.models import (
|
||||
create_reference_model,
|
||||
prepare_deepspeed,
|
||||
unwrap_model_for_generation,
|
||||
)
|
||||
from trl.trainer.grpo_config import GRPOConfig
|
||||
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
|
||||
from trl import GRPOTrainer
|
||||
|
||||
import copy
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftConfig, get_peft_model
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
||||
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
||||
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
||||
|
||||
|
||||
class RepeatRandomSampler(Sampler):
|
||||
"""
|
||||
Sampler that repeats the indices of a dataset N times.
|
||||
|
||||
Args:
|
||||
data_source (`Sized`):
|
||||
Dataset to sample from.
|
||||
repeat_count (`int`):
|
||||
Number of times to repeat each index.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
|
||||
>>> list(sampler)
|
||||
[2, 2, 0, 0, 3, 3, 1, 1]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, data_source, repeat_count: int):
|
||||
self.data_source = data_source
|
||||
self.repeat_count = repeat_count
|
||||
self.num_samples = len(data_source)
|
||||
|
||||
def __iter__(self):
|
||||
indexes = [
|
||||
idx
|
||||
for idx in torch.randperm(self.num_samples).tolist()
|
||||
for _ in range(self.repeat_count)
|
||||
]
|
||||
return iter(indexes)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples * self.repeat_count
|
||||
|
||||
|
||||
class Qwen2VLGRPOVLLMTrainer(Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, PreTrainedModel],
|
||||
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
||||
args: GRPOConfig = None,
|
||||
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
||||
eval_dataset: Optional[
|
||||
Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
|
||||
] = None,
|
||||
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
||||
reward_processing_classes: Optional[
|
||||
Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
|
||||
] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[
|
||||
Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
|
||||
] = (None, None),
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
# qwen2-vl related params
|
||||
max_pixels: Optional[int] = 12845056,
|
||||
min_pixels: Optional[int] = 3136,
|
||||
attn_implementation: str = "flash_attention_2",
|
||||
):
|
||||
|
||||
# Args
|
||||
if args is None:
|
||||
model_name = model if isinstance(model, str) else model.config._name_or_path
|
||||
model_name = model_name.split("/")[-1]
|
||||
args = GRPOConfig(f"{model_name}-GRPO")
|
||||
|
||||
# Models
|
||||
# Trained model
|
||||
model_init_kwargs = args.model_init_kwargs or {}
|
||||
model_init_kwargs["attn_implementation"] = attn_implementation
|
||||
if isinstance(model, str):
|
||||
model_id = model
|
||||
torch_dtype = model_init_kwargs.get("torch_dtype")
|
||||
if (
|
||||
isinstance(torch_dtype, torch.dtype)
|
||||
or torch_dtype == "auto"
|
||||
or torch_dtype is None
|
||||
):
|
||||
pass # torch_dtype is already a torch.dtype or "auto" or None
|
||||
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
||||
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
||||
)
|
||||
# Disable caching if gradient checkpointing is enabled (not supported)
|
||||
model_init_kwargs["use_cache"] = (
|
||||
False
|
||||
if args.gradient_checkpointing
|
||||
else model_init_kwargs.get("use_cache")
|
||||
)
|
||||
if "Qwen2-VL" in model_id:
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model, **model_init_kwargs
|
||||
)
|
||||
elif "Qwen2.5-VL" in model_id:
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
|
||||
elif "Aria" in model_id:
|
||||
model_init_kwargs.pop("use_cache")
|
||||
model = AriaForConditionalGeneration.from_pretrained(
|
||||
model, **model_init_kwargs
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
||||
else:
|
||||
model_id = model.config._name_or_path
|
||||
if args.model_init_kwargs is not None:
|
||||
raise ValueError(
|
||||
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
||||
"This argument can only be used when the `model` argument is a string."
|
||||
)
|
||||
|
||||
if peft_config is not None:
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
# Reference model
|
||||
if is_deepspeed_zero3_enabled():
|
||||
if "Qwen2-VL" in model_id:
|
||||
self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_id, **model_init_kwargs
|
||||
)
|
||||
elif "Aria" in model_id:
|
||||
self.ref_model = AriaForConditionalGeneration.from_pretrained(
|
||||
model_id, **model_init_kwargs
|
||||
)
|
||||
else:
|
||||
self.ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, **model_init_kwargs
|
||||
)
|
||||
elif peft_config is None:
|
||||
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
||||
self.ref_model = create_reference_model(model)
|
||||
else:
|
||||
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
||||
# to revert to the initial model.
|
||||
self.ref_model = None
|
||||
|
||||
# Processing class
|
||||
if processing_class is None:
|
||||
if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id:
|
||||
processing_class = AutoProcessor.from_pretrained(model_id)
|
||||
pad_token_id = processing_class.tokenizer.pad_token_id
|
||||
processing_class.pad_token_id = pad_token_id
|
||||
processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
|
||||
if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
|
||||
processing_class.image_processor.max_pixels = max_pixels
|
||||
processing_class.image_processor.min_pixels = min_pixels
|
||||
else:
|
||||
processing_class = AutoTokenizer.from_pretrained(
|
||||
model.config._name_or_path, padding_side="left"
|
||||
)
|
||||
pad_token_id = processing_class.pad_token_id
|
||||
|
||||
# Reward functions
|
||||
if not isinstance(reward_funcs, list):
|
||||
reward_funcs = [reward_funcs]
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
if isinstance(reward_func, str):
|
||||
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
||||
reward_func, num_labels=1, **model_init_kwargs
|
||||
)
|
||||
self.reward_funcs = reward_funcs
|
||||
|
||||
# Reward processing class
|
||||
if reward_processing_classes is None:
|
||||
reward_processing_classes = [None] * len(reward_funcs)
|
||||
elif not isinstance(reward_processing_classes, list):
|
||||
reward_processing_classes = [reward_processing_classes]
|
||||
else:
|
||||
if len(reward_processing_classes) != len(reward_funcs):
|
||||
raise ValueError(
|
||||
"The number of reward processing classes must match the number of reward functions."
|
||||
)
|
||||
|
||||
for i, (reward_processing_class, reward_func) in enumerate(
|
||||
zip(reward_processing_classes, reward_funcs)
|
||||
):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if reward_processing_class is None:
|
||||
reward_processing_class = AutoTokenizer.from_pretrained(
|
||||
reward_func.config._name_or_path
|
||||
)
|
||||
if reward_processing_class.pad_token_id is None:
|
||||
reward_processing_class.pad_token = (
|
||||
reward_processing_class.eos_token
|
||||
)
|
||||
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
||||
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
||||
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
||||
reward_processing_classes[i] = reward_processing_class
|
||||
self.reward_processing_classes = reward_processing_classes
|
||||
|
||||
# Data collator
|
||||
def data_collator(features): # No data collation is needed in GRPO
|
||||
return features
|
||||
|
||||
# Training arguments
|
||||
self.max_prompt_length = args.max_prompt_length
|
||||
self.max_completion_length = (
|
||||
args.max_completion_length
|
||||
) # = |o_i| in the GRPO paper
|
||||
self.num_generations = args.num_generations # = G in the GRPO paper
|
||||
self.generation_config = GenerationConfig(
|
||||
max_new_tokens=self.max_completion_length,
|
||||
do_sample=True,
|
||||
temperature=1, # HACK
|
||||
num_return_sequences=self.num_generations,
|
||||
pad_token_id=pad_token_id,
|
||||
)
|
||||
self.beta = args.beta
|
||||
|
||||
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
||||
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
||||
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
||||
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
||||
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
||||
# This acts as a flag to indicate that the warning has already been issued.
|
||||
model.warnings_issued["estimate_tokens"] = True
|
||||
|
||||
# Initialize the metrics
|
||||
self._metrics = defaultdict(list)
|
||||
self.use_vllm = args.use_vllm
|
||||
|
||||
# # rewrite the processing AutoTokenizer -> AutoProcessor
|
||||
# model_id = model if isinstance(model, str) else model.config._name_or_path
|
||||
# if processing_class is None:
|
||||
# if "Qwen2-VL" in model_id or "Aria" in model_id:
|
||||
# processing_class = AutoProcessor.from_pretrained(model_id)
|
||||
# pad_token_id = processing_class.tokenizer.pad_token_id
|
||||
# processing_class.pad_token_id = pad_token_id
|
||||
# processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
|
||||
# if "Qwen2-VL" in model_id:
|
||||
# processing_class.image_processor.max_pixels = max_pixels
|
||||
# processing_class.image_processor.min_pixels = min_pixels
|
||||
# else:
|
||||
# processing_class = AutoTokenizer.from_pretrained(
|
||||
# model.config._name_or_path, padding_side="left"
|
||||
# )
|
||||
# pad_token_id = processing_class.pad_token_id
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
)
|
||||
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
||||
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
||||
# self.model_accepts_loss_kwargs to False to enable scaling.
|
||||
self.model_accepts_loss_kwargs = False
|
||||
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
||||
num_processes = self.accelerator.num_processes
|
||||
global_batch_size = args.per_device_train_batch_size * num_processes
|
||||
possible_values = [
|
||||
n_gen
|
||||
for n_gen in range(2, global_batch_size + 1)
|
||||
if (global_batch_size) % n_gen == 0
|
||||
]
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
||||
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
||||
f"batch size, the valid values for the number of generations are: {possible_values}."
|
||||
)
|
||||
if self.args.eval_strategy != "no":
|
||||
global_batch_size = args.per_device_eval_batch_size * num_processes
|
||||
possible_values = [
|
||||
n_gen
|
||||
for n_gen in range(2, global_batch_size + 1)
|
||||
if (global_batch_size) % n_gen == 0
|
||||
]
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
||||
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
||||
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
||||
)
|
||||
|
||||
if self.use_vllm:
|
||||
if not is_vllm_available():
|
||||
raise ImportError(
|
||||
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
|
||||
"`pip install vllm` to use it."
|
||||
)
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
vllm_device = self.args.vllm_device
|
||||
if vllm_device == "auto":
|
||||
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
|
||||
# Check that the requested device is available
|
||||
if (
|
||||
vllm_device.split(":")[0] == "cuda"
|
||||
and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
|
||||
):
|
||||
raise ValueError(
|
||||
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
|
||||
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
|
||||
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
|
||||
f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
|
||||
)
|
||||
# Check that the requested device is not also used for training
|
||||
if vllm_device in {
|
||||
f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
|
||||
}:
|
||||
warnings.warn(
|
||||
f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
|
||||
"behavior. It is recommended to use a dedicated device for vLLM."
|
||||
)
|
||||
# vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
|
||||
# model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
|
||||
# setting (profiling_patch).
|
||||
world_size_patch = patch(
|
||||
"torch.distributed.get_world_size", return_value=1
|
||||
)
|
||||
profiling_patch = patch(
|
||||
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
|
||||
return_value=None,
|
||||
)
|
||||
with world_size_patch, profiling_patch:
|
||||
print("vllm is running on: ", vllm_device)
|
||||
self.llm = LLM(
|
||||
model=model.name_or_path,
|
||||
device=vllm_device,
|
||||
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
|
||||
dtype=torch.bfloat16,
|
||||
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
||||
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
||||
# This is particularly useful here because we generate completions from the same prompts.
|
||||
enable_prefix_caching=True,
|
||||
enforce_eager=True,
|
||||
max_model_len=args.max_completion_length,
|
||||
)
|
||||
self.sampling_params = SamplingParams(
|
||||
temperature=args.temperature,
|
||||
max_tokens=self.max_completion_length,
|
||||
)
|
||||
|
||||
self._last_loaded_step = (
|
||||
0 # tag to avoid useless loading during grad accumulation
|
||||
)
|
||||
|
||||
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
|
||||
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
|
||||
# synchronize all processes after vLLM has been fully initialized.
|
||||
self.accelerator.wait_for_everyone()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Qwen2VLGRPOVLLMTrainer only supports vllm generation, please set --use_vllm True"
|
||||
)
|
||||
|
||||
if self.ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(
|
||||
self.ref_model, evaluation_mode=True
|
||||
)
|
||||
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
self.reward_funcs[i] = self.accelerator.prepare_model(
|
||||
reward_func, evaluation_mode=True
|
||||
)
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
||||
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
||||
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
||||
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
||||
if self._signature_columns is None:
|
||||
self._signature_columns = ["prompt"]
|
||||
|
||||
# We need a custom sampler that samples the same prompt multiple times
|
||||
def _get_train_sampler(self):
|
||||
return RepeatRandomSampler(self.train_dataset, self.num_generations)
|
||||
|
||||
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||
def _get_per_token_logps(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
logits_to_keep,
|
||||
):
|
||||
pixel_values = pixel_values.to(model.device)
|
||||
image_grid_thw = image_grid_thw.to(device=model.device)
|
||||
logits = model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
).logits # (B, L, V)
|
||||
logits = logits[
|
||||
:, :-1, :
|
||||
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||
input_ids = input_ids[
|
||||
:, -logits_to_keep:
|
||||
] # (B, L-1), exclude the first input ID since we don't have logits for it
|
||||
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
|
||||
logits = logits[:, -logits_to_keep:]
|
||||
per_token_logps = []
|
||||
for logits_row, input_ids_row in zip(logits, input_ids):
|
||||
log_probs = logits_row.log_softmax(dim=-1)
|
||||
token_log_prob = torch.gather(
|
||||
log_probs, dim=1, index=input_ids_row.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
per_token_logps.append(token_log_prob)
|
||||
return torch.stack(per_token_logps)
|
||||
|
||||
# Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
|
||||
# Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
|
||||
def _prepare_inputs(
|
||||
self, inputs: dict[str, Union[torch.Tensor, Any]]
|
||||
) -> dict[str, Union[torch.Tensor, Any]]:
|
||||
device = self.accelerator.device
|
||||
prompts = [x["prompt"] for x in inputs]
|
||||
images = [x["image"] for x in inputs]
|
||||
prompts_text = [
|
||||
maybe_apply_chat_template(example, self.processing_class)["prompt"]
|
||||
for example in inputs
|
||||
]
|
||||
prompt_inputs = self.processing_class(
|
||||
# prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
||||
text=prompts_text,
|
||||
images=images,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
add_special_tokens=False,
|
||||
)
|
||||
prompt_ids, prompt_mask = (
|
||||
prompt_inputs["input_ids"].to(device),
|
||||
prompt_inputs["attention_mask"].to(device),
|
||||
)
|
||||
if self.max_prompt_length is not None:
|
||||
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
||||
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
||||
|
||||
if self.args.use_vllm:
|
||||
# First, have main process load weights if needed
|
||||
if self.state.global_step != self._last_loaded_step:
|
||||
with unwrap_model_for_generation(
|
||||
self.model,
|
||||
self.accelerator,
|
||||
gather_deepspeed3_params=False, # TODO: fix this, self.args.ds3_gather_for_generation,
|
||||
) as unwrapped_model:
|
||||
if is_compiled_module(unwrapped_model):
|
||||
state_dict = unwrapped_model._orig_mod.state_dict()
|
||||
else:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
llm_model.load_weights(state_dict.items())
|
||||
self._last_loaded_step = self.state.global_step
|
||||
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
all_images = gather_object(images)
|
||||
# group into pairs
|
||||
all_multimodal_inputs = [
|
||||
{"prompt": p, "multi_modal_data": {"image": i}}
|
||||
for p, i in zip(all_prompts_text, all_images)
|
||||
]
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
outputs = self.llm.generate(
|
||||
all_multimodal_inputs,
|
||||
sampling_params=self.sampling_params,
|
||||
use_tqdm=False,
|
||||
)
|
||||
completion_ids = [
|
||||
out.token_ids
|
||||
for completions in outputs
|
||||
for out in completions.outputs
|
||||
]
|
||||
else:
|
||||
completion_ids = [None] * len(all_prompts_text)
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
)
|
||||
completion_ids = completion_ids[process_slice]
|
||||
|
||||
# Pad the completions, and concatenate them with the prompts
|
||||
completion_ids = [
|
||||
torch.tensor(ids, device=device) for ids in completion_ids
|
||||
]
|
||||
completion_ids = pad(
|
||||
completion_ids, padding_value=self.processing_class.pad_token_id
|
||||
)
|
||||
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
else:
|
||||
raise ValueError("Only vLLM generation is supported in this version ")
|
||||
|
||||
# below are the same with yifan's code
|
||||
# Mask everything after the first EOS token
|
||||
is_eos = completion_ids == self.processing_class.eos_token_id
|
||||
device = self.accelerator.device
|
||||
eos_idx = torch.full(
|
||||
(is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
|
||||
)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
|
||||
is_eos.size(0), -1
|
||||
)
|
||||
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
||||
|
||||
# Concatenate prompt_mask with completion_mask for logit computation
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
|
||||
# pixel_values = prompt_inputs["pixel_values"].repeat_interleave(
|
||||
# self.num_generations, dim=0
|
||||
# )
|
||||
|
||||
pixel_values = prompt_inputs["pixel_values"]
|
||||
# [None].repeat_interleave(self.num_generations, dim=0)
|
||||
# pixel_values = pixel_values.view(-1, pixel_values.shape[-1])
|
||||
|
||||
image_grid_thw = prompt_inputs["image_grid_thw"]
|
||||
# .repeat_interleave(
|
||||
# self.num_generations, dim=0
|
||||
# )
|
||||
logits_to_keep = completion_ids.size(1)
|
||||
|
||||
with torch.inference_mode():
|
||||
if self.ref_model is not None:
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
self.ref_model,
|
||||
prompt_completion_ids,
|
||||
attention_mask,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
logits_to_keep,
|
||||
)
|
||||
else:
|
||||
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
self.model,
|
||||
prompt_completion_ids,
|
||||
attention_mask,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
logits_to_keep,
|
||||
)
|
||||
|
||||
# Decode the generated completions
|
||||
completions = self.processing_class.batch_decode(
|
||||
completion_ids, skip_special_tokens=True
|
||||
)
|
||||
if is_conversational(inputs[0]):
|
||||
completions = [
|
||||
[{"role": "assistant", "content": completion}]
|
||||
for completion in completions
|
||||
]
|
||||
|
||||
# Compute the rewards
|
||||
rewards_per_func = torch.zeros(
|
||||
len(prompts), len(self.reward_funcs), device=device
|
||||
)
|
||||
for i, (reward_func, reward_processing_class) in enumerate(
|
||||
zip(self.reward_funcs, self.reward_processing_classes)
|
||||
):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if is_conversational(inputs[0]):
|
||||
messages = [
|
||||
{"messages": p + c} for p, c in zip(prompts, completions)
|
||||
]
|
||||
texts = [
|
||||
apply_chat_template(x, reward_processing_class)["text"]
|
||||
for x in messages
|
||||
]
|
||||
else:
|
||||
texts = [p + c for p, c in zip(prompts, completions)]
|
||||
reward_inputs = reward_processing_class(
|
||||
texts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding_side="right",
|
||||
add_special_tokens=False,
|
||||
)
|
||||
reward_inputs = super()._prepare_inputs(reward_inputs)
|
||||
with torch.inference_mode():
|
||||
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
|
||||
:, 0
|
||||
] # Shape (B*G,)
|
||||
else:
|
||||
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
||||
reward_kwargs = {
|
||||
key: []
|
||||
for key in inputs[0].keys()
|
||||
if key not in ["prompt", "completion"]
|
||||
}
|
||||
for key in reward_kwargs:
|
||||
for example in inputs:
|
||||
# Repeat each value in the column for `num_generations` times
|
||||
reward_kwargs[key].extend([example[key]] * self.num_generations)
|
||||
output_reward_func = reward_func(
|
||||
prompts=prompts, completions=completions, **reward_kwargs
|
||||
)
|
||||
rewards_per_func[:, i] = torch.tensor(
|
||||
output_reward_func, dtype=torch.float32, device=device
|
||||
)
|
||||
rewards_per_func = gather(rewards_per_func)
|
||||
# Sum the rewards from all reward functions
|
||||
rewards = rewards_per_func.sum(dim=1)
|
||||
|
||||
# Compute grouped-wise rewards
|
||||
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
||||
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
||||
|
||||
# Normalize the rewards to compute the advantages
|
||||
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
|
||||
self.num_generations, dim=0
|
||||
)
|
||||
std_grouped_rewards = std_grouped_rewards.repeat_interleave(
|
||||
self.num_generations, dim=0
|
||||
)
|
||||
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# Slice to keep only the local part of the data
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
)
|
||||
advantages = advantages[process_slice]
|
||||
|
||||
# Log the metrics
|
||||
reward_per_func = rewards_per_func.mean(0)
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(
|
||||
reward_func, nn.Module
|
||||
): # Module instead of PretrainedModel for compat with compiled models
|
||||
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
||||
else:
|
||||
reward_func_name = reward_func.__name__
|
||||
self._metrics[f"rewards/{reward_func_name}"].append(
|
||||
reward_per_func[i].item()
|
||||
)
|
||||
|
||||
self._metrics["reward"].append(rewards.mean().item())
|
||||
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
|
||||
|
||||
return {
|
||||
"prompt_ids": prompt_ids,
|
||||
"prompt_mask": prompt_mask,
|
||||
"completion_ids": completion_ids,
|
||||
"completion_mask": completion_mask,
|
||||
"ref_per_token_logps": ref_per_token_logps,
|
||||
"advantages": advantages,
|
||||
"pixel_values": pixel_values,
|
||||
"image_grid_thw": image_grid_thw,
|
||||
}
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
if return_outputs:
|
||||
raise ValueError("The GRPOTrainer does not support returning outputs")
|
||||
# Compute the per-token log probabilities for the model
|
||||
|
||||
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
||||
completion_ids, completion_mask = (
|
||||
inputs["completion_ids"],
|
||||
inputs["completion_mask"],
|
||||
)
|
||||
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
||||
pixel_values = inputs["pixel_values"]
|
||||
image_grid_thw = inputs["image_grid_thw"]
|
||||
logits_to_keep = completion_ids.size(
|
||||
1
|
||||
) # we only need to compute the logits for the completion tokens
|
||||
|
||||
per_token_logps = self._get_per_token_logps(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
logits_to_keep,
|
||||
)
|
||||
|
||||
# Compute the KL divergence between the model and the reference model
|
||||
ref_per_token_logps = inputs["ref_per_token_logps"]
|
||||
per_token_kl = (
|
||||
torch.exp(ref_per_token_logps - per_token_logps)
|
||||
- (ref_per_token_logps - per_token_logps)
|
||||
- 1
|
||||
)
|
||||
|
||||
# x - x.detach() allows for preserving gradients from x
|
||||
advantages = inputs["advantages"]
|
||||
per_token_loss = torch.exp(
|
||||
per_token_logps - per_token_logps.detach()
|
||||
) * advantages.unsqueeze(1)
|
||||
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
|
||||
loss = (
|
||||
(per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
|
||||
).mean()
|
||||
|
||||
# Log the metrics
|
||||
completion_length = (
|
||||
self.accelerator.gather_for_metrics(completion_mask.sum(1))
|
||||
.float()
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
self._metrics["completion_length"].append(completion_length)
|
||||
|
||||
mean_kl = (
|
||||
(per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
|
||||
).mean()
|
||||
self._metrics["kl"].append(
|
||||
self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
||||
|
||||
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
||||
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
||||
if next(iter(logs.keys())).startswith("eval_"):
|
||||
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
||||
|
||||
logs = {**logs, **metrics}
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
super().log(logs, start_time)
|
||||
else: # transformers<=4.46
|
||||
super().log(logs)
|
||||
self._metrics.clear()
|
||||
@@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import subprocess
|
||||
from typing import List
|
||||
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_callback import TrainerControl, TrainerState
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
from .evaluation import run_benchmark_jobs
|
||||
from .hub import push_to_hub_revision
|
||||
|
||||
|
||||
def is_slurm_available() -> bool:
|
||||
# returns true if a slurm queueing system is available
|
||||
try:
|
||||
subprocess.run(["sinfo"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class PushToHubRevisionCallback(TrainerCallback):
|
||||
def __init__(self, model_config) -> None:
|
||||
self.model_config = model_config
|
||||
|
||||
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
global_step = state.global_step
|
||||
|
||||
# WARNING: if you use dataclasses.replace(args, ...) the accelerator dist state will be broken, so I do this workaround
|
||||
# Also if you instantiate a new SFTConfig, the accelerator dist state will be broken
|
||||
dummy_config = DummyConfig(
|
||||
hub_model_id=args.hub_model_id,
|
||||
hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}",
|
||||
output_dir=f"{args.output_dir}/checkpoint-{global_step}",
|
||||
system_prompt=args.system_prompt,
|
||||
)
|
||||
|
||||
future = push_to_hub_revision(
|
||||
dummy_config, extra_ignore_patterns=["*.pt"]
|
||||
) # don't push the optimizer states
|
||||
|
||||
if is_slurm_available():
|
||||
dummy_config.benchmarks = args.benchmarks
|
||||
|
||||
def run_benchmark_callback(_):
|
||||
print(f"Checkpoint {global_step} pushed to hub.")
|
||||
run_benchmark_jobs(dummy_config, self.model_config)
|
||||
|
||||
future.add_done_callback(run_benchmark_callback)
|
||||
|
||||
|
||||
CALLBACKS = {
|
||||
"push_to_hub_revision": PushToHubRevisionCallback,
|
||||
}
|
||||
|
||||
|
||||
def get_callbacks(train_config, model_config) -> List[TrainerCallback]:
|
||||
callbacks = []
|
||||
for callback_name in train_config.callbacks:
|
||||
if callback_name not in CALLBACKS:
|
||||
raise ValueError(f"Callback {callback_name} not found in CALLBACKS.")
|
||||
callbacks.append(CALLBACKS[callback_name](model_config))
|
||||
|
||||
return callbacks
|
||||
@@ -0,0 +1,105 @@
|
||||
import subprocess
|
||||
from typing import TYPE_CHECKING, Dict, Union
|
||||
|
||||
from .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trl import GRPOConfig, SFTConfig, ModelConfig
|
||||
|
||||
import os
|
||||
|
||||
|
||||
# We need a special environment setup to launch vLLM from within Slurm training jobs.
|
||||
# - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105
|
||||
# - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269
|
||||
user_home_directory = os.path.expanduser("~")
|
||||
VLLM_SLURM_PREFIX = [
|
||||
"env",
|
||||
"-i",
|
||||
"bash",
|
||||
"-c",
|
||||
f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch ",
|
||||
]
|
||||
|
||||
|
||||
def register_lighteval_task(
|
||||
configs: Dict[str, str], eval_suite: str, task_name: str, task_list: str, num_fewshot: int = 0
|
||||
):
|
||||
"""Registers a LightEval task configuration.
|
||||
|
||||
- Core tasks can be added from this table: https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl
|
||||
- Custom tasks that require their own metrics / scripts, should be stored in scripts/evaluation/extended_lighteval_tasks
|
||||
|
||||
Args:
|
||||
configs (Dict[str, str]): The dictionary to store the task configuration.
|
||||
eval_suite (str, optional): The evaluation suite.
|
||||
task_name (str): The name of the task.
|
||||
task_list (str): The comma-separated list of tasks in the format "extended|{task_name}|{num_fewshot}|0" or "lighteval|{task_name}|{num_fewshot}|0".
|
||||
num_fewshot (int, optional): The number of few-shot examples. Defaults to 0.
|
||||
is_custom_task (bool, optional): Whether the task is a custom task. Defaults to False.
|
||||
"""
|
||||
# Format task list in lighteval format
|
||||
task_list = ",".join(f"{eval_suite}|{task}|{num_fewshot}|0" for task in task_list.split(","))
|
||||
configs[task_name] = task_list
|
||||
|
||||
|
||||
LIGHTEVAL_TASKS = {}
|
||||
|
||||
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "math_500", "math_500", 0)
|
||||
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime24", "aime24", 0)
|
||||
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime25_part1", "aime25:part1", 0)
|
||||
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "gpqa", "gpqa:diamond", 0)
|
||||
|
||||
|
||||
def get_lighteval_tasks():
|
||||
return list(LIGHTEVAL_TASKS.keys())
|
||||
|
||||
|
||||
SUPPORTED_BENCHMARKS = get_lighteval_tasks()
|
||||
|
||||
|
||||
def run_lighteval_job(
|
||||
benchmark: str, training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig"
|
||||
) -> None:
|
||||
task_list = LIGHTEVAL_TASKS[benchmark]
|
||||
model_name = training_args.hub_model_id
|
||||
model_revision = training_args.hub_model_revision
|
||||
# For large models >= 30b params or those running the MATH benchmark, we need to shard them across the GPUs to avoid OOM
|
||||
num_gpus = get_gpu_count_for_vllm(model_name, model_revision)
|
||||
if get_param_count_from_repo_id(model_name) >= 30_000_000_000:
|
||||
tensor_parallel = True
|
||||
else:
|
||||
tensor_parallel = False
|
||||
|
||||
cmd = VLLM_SLURM_PREFIX.copy()
|
||||
cmd_args = [
|
||||
f"--gres=gpu:{num_gpus}",
|
||||
f"--job-name=or1_{benchmark}_{model_name.split('/')[-1]}_{model_revision}",
|
||||
"slurm/evaluate.slurm",
|
||||
benchmark,
|
||||
f'"{task_list}"',
|
||||
model_name,
|
||||
model_revision,
|
||||
f"{tensor_parallel}",
|
||||
f"{model_args.trust_remote_code}",
|
||||
]
|
||||
if training_args.system_prompt is not None:
|
||||
cmd_args.append(f"--system_prompt={training_args.system_prompt}")
|
||||
cmd[-1] += " " + " ".join(cmd_args)
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
def run_benchmark_jobs(training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig") -> None:
|
||||
benchmarks = training_args.benchmarks
|
||||
if len(benchmarks) == 1 and benchmarks[0] == "all":
|
||||
benchmarks = get_lighteval_tasks()
|
||||
# Evaluate on all supported benchmarks. Later we may want to include a `chat` option
|
||||
# that just evaluates on `ifeval` and `mt_bench` etc.
|
||||
|
||||
for benchmark in benchmarks:
|
||||
print(f"Launching benchmark `{benchmark}`")
|
||||
if benchmark in get_lighteval_tasks():
|
||||
run_lighteval_job(benchmark, training_args, model_args)
|
||||
else:
|
||||
raise ValueError(f"Unknown benchmark {benchmark}")
|
||||
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
from concurrent.futures import Future
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
from huggingface_hub import (
|
||||
create_branch,
|
||||
create_repo,
|
||||
get_safetensors_metadata,
|
||||
list_repo_commits,
|
||||
list_repo_files,
|
||||
list_repo_refs,
|
||||
repo_exists,
|
||||
upload_folder,
|
||||
)
|
||||
from trl import GRPOConfig, SFTConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def push_to_hub_revision(training_args: SFTConfig | GRPOConfig, extra_ignore_patterns=[]) -> Future:
|
||||
"""Pushes the model to branch on a Hub repo."""
|
||||
|
||||
# Create a repo if it doesn't exist yet
|
||||
repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True)
|
||||
# Get initial commit to branch from
|
||||
initial_commit = list_repo_commits(training_args.hub_model_id)[-1]
|
||||
# Now create the branch we'll be pushing to
|
||||
create_branch(
|
||||
repo_id=training_args.hub_model_id,
|
||||
branch=training_args.hub_model_revision,
|
||||
revision=initial_commit.commit_id,
|
||||
exist_ok=True,
|
||||
)
|
||||
logger.info(f"Created target repo at {repo_url}")
|
||||
logger.info(f"Pushing to the Hub revision {training_args.hub_model_revision}...")
|
||||
ignore_patterns = ["checkpoint-*", "*.pth"]
|
||||
ignore_patterns.extend(extra_ignore_patterns)
|
||||
future = upload_folder(
|
||||
repo_id=training_args.hub_model_id,
|
||||
folder_path=training_args.output_dir,
|
||||
revision=training_args.hub_model_revision,
|
||||
commit_message=f"Add {training_args.hub_model_revision} checkpoint",
|
||||
ignore_patterns=ignore_patterns,
|
||||
run_as_future=True,
|
||||
)
|
||||
logger.info(f"Pushed to {repo_url} revision {training_args.hub_model_revision} successfully!")
|
||||
|
||||
return future
|
||||
|
||||
|
||||
def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig):
|
||||
"""Checks if a given Hub revision exists."""
|
||||
if repo_exists(training_args.hub_model_id):
|
||||
if training_args.push_to_hub_revision is True:
|
||||
# First check if the revision exists
|
||||
revisions = [rev.name for rev in list_repo_refs(training_args.hub_model_id).branches]
|
||||
# If the revision exists, we next check it has a README file
|
||||
if training_args.hub_model_revision in revisions:
|
||||
repo_files = list_repo_files(
|
||||
repo_id=training_args.hub_model_id, revision=training_args.hub_model_revision
|
||||
)
|
||||
if "README.md" in repo_files and training_args.overwrite_hub_revision is False:
|
||||
raise ValueError(
|
||||
f"Revision {training_args.hub_model_revision} already exists. "
|
||||
"Use --overwrite_hub_revision to overwrite it."
|
||||
)
|
||||
|
||||
|
||||
def get_param_count_from_repo_id(repo_id: str) -> int:
|
||||
"""Function to get model param counts from safetensors metadata or find patterns like 42m, 1.5b, 0.5m or products like 8x7b in a repo ID."""
|
||||
try:
|
||||
metadata = get_safetensors_metadata(repo_id)
|
||||
return list(metadata.parameter_count.values())[0]
|
||||
except Exception:
|
||||
# Pattern to match products (like 8x7b) and single values (like 42m)
|
||||
pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])"
|
||||
matches = re.findall(pattern, repo_id.lower())
|
||||
|
||||
param_counts = []
|
||||
for full_match, number1, _, _, number2, _, unit in matches:
|
||||
if number2: # If there's a second number, it's a product
|
||||
number = float(number1) * float(number2)
|
||||
else: # Otherwise, it's a single value
|
||||
number = float(number1)
|
||||
|
||||
if unit == "b":
|
||||
number *= 1_000_000_000 # Convert to billion
|
||||
elif unit == "m":
|
||||
number *= 1_000_000 # Convert to million
|
||||
|
||||
param_counts.append(number)
|
||||
|
||||
if len(param_counts) > 0:
|
||||
# Return the largest number
|
||||
return int(max(param_counts))
|
||||
else:
|
||||
# Return -1 if no match found
|
||||
return -1
|
||||
|
||||
|
||||
def get_gpu_count_for_vllm(model_name: str, revision: str = "main", num_gpus: int = 8) -> int:
|
||||
"""vLLM enforces a constraint that the number of attention heads must be divisible by the number of GPUs and 64 must be divisible by the number of GPUs.
|
||||
This function calculates the number of GPUs to use for decoding based on the number of attention heads in the model.
|
||||
"""
|
||||
config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
|
||||
# Get number of attention heads
|
||||
num_heads = config.num_attention_heads
|
||||
# Reduce num_gpus so that num_heads is divisible by num_gpus and 64 is divisible by num_gpus
|
||||
while num_heads % num_gpus != 0 or 64 % num_gpus != 0:
|
||||
logger.info(f"Reducing num_gpus from {num_gpus} to {num_gpus - 1} to make num_heads divisible by num_gpus")
|
||||
num_gpus -= 1
|
||||
return num_gpus
|
||||
@@ -0,0 +1,220 @@
|
||||
from math_verify import parse, verify
|
||||
def compute_score(solution_str, ground_truth) -> float:
|
||||
retval = 0.
|
||||
|
||||
if solution_str == ground_truth:
|
||||
return 1.0
|
||||
|
||||
if float(verify(parse(solution_str), parse(ground_truth))) > 0:
|
||||
return 1.0
|
||||
|
||||
try:
|
||||
answer = solution_str
|
||||
string_in_last_boxed = last_boxed_only_string(solution_str)
|
||||
if string_in_last_boxed is not None:
|
||||
answer = remove_boxed(string_in_last_boxed)
|
||||
|
||||
if is_equiv(answer, ground_truth):
|
||||
return 1.0
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return retval
|
||||
|
||||
|
||||
def remove_boxed(s):
|
||||
if "\\boxed " in s:
|
||||
left = "\\boxed "
|
||||
assert s[:len(left)] == left
|
||||
return s[len(left):]
|
||||
|
||||
left = "\\boxed{"
|
||||
|
||||
assert s[:len(left)] == left
|
||||
assert s[-1] == "}"
|
||||
|
||||
return s[len(left):-1]
|
||||
|
||||
def last_boxed_only_string(string):
|
||||
idx = string.rfind("\\boxed")
|
||||
if "\\boxed " in string:
|
||||
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
|
||||
if idx < 0:
|
||||
idx = string.rfind("\\fbox")
|
||||
if idx < 0:
|
||||
return None
|
||||
|
||||
i = idx
|
||||
right_brace_idx = None
|
||||
num_left_braces_open = 0
|
||||
while i < len(string):
|
||||
if string[i] == "{":
|
||||
num_left_braces_open += 1
|
||||
if string[i] == "}":
|
||||
num_left_braces_open -= 1
|
||||
if num_left_braces_open == 0:
|
||||
right_brace_idx = i
|
||||
break
|
||||
i += 1
|
||||
|
||||
if right_brace_idx is None:
|
||||
retval = None
|
||||
else:
|
||||
retval = string[idx:right_brace_idx + 1]
|
||||
|
||||
return retval
|
||||
|
||||
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
|
||||
def is_equiv(str1, str2, verbose=False):
|
||||
if str1 is None and str2 is None:
|
||||
print("WARNING: Both None")
|
||||
return True
|
||||
if str1 is None or str2 is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
ss1 = strip_string(str1)
|
||||
ss2 = strip_string(str2)
|
||||
if verbose:
|
||||
print(ss1, ss2)
|
||||
return ss1 == ss2
|
||||
except Exception:
|
||||
return str1 == str2
|
||||
|
||||
|
||||
|
||||
def fix_fracs(string):
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except AssertionError:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def fix_a_slash_b(string):
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
a = int(a)
|
||||
b = int(b)
|
||||
assert string == "{}/{}".format(a, b)
|
||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||
return new_string
|
||||
except AssertionError:
|
||||
return string
|
||||
|
||||
|
||||
def remove_right_units(string):
|
||||
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
||||
if "\\text{ " in string:
|
||||
splits = string.split("\\text{ ")
|
||||
assert len(splits) == 2
|
||||
return splits[0]
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def fix_sqrt(string):
|
||||
if "\\sqrt" not in string:
|
||||
return string
|
||||
splits = string.split("\\sqrt")
|
||||
new_string = splits[0]
|
||||
for split in splits[1:]:
|
||||
if split[0] != "{":
|
||||
a = split[0]
|
||||
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
||||
else:
|
||||
new_substr = "\\sqrt" + split
|
||||
new_string += new_substr
|
||||
return new_string
|
||||
|
||||
|
||||
def strip_string(string):
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
|
||||
# remove units (on the right)
|
||||
string = remove_right_units(string)
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace("\%", "") # noqa: W605
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
if len(string.split("=")) == 2:
|
||||
if len(string.split("=")[0]) <= 2:
|
||||
string = string.split("=")[1]
|
||||
|
||||
# fix sqrt3 --> sqrt{3}
|
||||
string = fix_sqrt(string)
|
||||
|
||||
# remove spaces
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = fix_fracs(string)
|
||||
|
||||
# manually change 0.5 --> \frac{1}{2}
|
||||
if string == "0.5":
|
||||
string = "\\frac{1}{2}"
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
@@ -0,0 +1,398 @@
|
||||
import json
|
||||
import time
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import Polygon
|
||||
import numpy as np
|
||||
import copy
|
||||
import itertools
|
||||
#from . import mask as maskUtils
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
PYTHON_VERSION = sys.version_info[0]
|
||||
if PYTHON_VERSION == 2:
|
||||
from urllib import urlretrieve
|
||||
elif PYTHON_VERSION == 3:
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
|
||||
def _isArrayLike(obj):
|
||||
return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
|
||||
|
||||
|
||||
class COCO:
|
||||
def __init__(self, annotation_file=None):
|
||||
"""
|
||||
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
|
||||
:param annotation_file (str): location of annotation file
|
||||
:param image_folder (str): location to the folder that hosts images.
|
||||
:return:
|
||||
"""
|
||||
# load dataset
|
||||
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
|
||||
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
|
||||
if not annotation_file == None:
|
||||
# print('loading annotations into memory...')
|
||||
tic = time.time()
|
||||
if type(annotation_file) == dict:
|
||||
dataset = annotation_file
|
||||
else:
|
||||
dataset = json.load(open(annotation_file, 'r'))
|
||||
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
|
||||
# print('Done (t={:0.2f}s)'.format(time.time()- tic))
|
||||
self.dataset = dataset
|
||||
self.createIndex()
|
||||
|
||||
def createIndex(self):
|
||||
# create index
|
||||
# print('creating index...')
|
||||
anns, cats, imgs = {}, {}, {}
|
||||
imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
|
||||
if 'annotations' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
imgToAnns[ann['image_id']].append(ann)
|
||||
anns[ann['id']] = ann
|
||||
|
||||
if 'images' in self.dataset:
|
||||
for img in self.dataset['images']:
|
||||
imgs[img['id']] = img
|
||||
|
||||
if 'categories' in self.dataset:
|
||||
for cat in self.dataset['categories']:
|
||||
cats[cat['id']] = cat
|
||||
|
||||
if 'annotations' in self.dataset and 'categories' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
catToImgs[ann['category_id']].append(ann['image_id'])
|
||||
|
||||
# print('index created!')
|
||||
|
||||
# create class members
|
||||
self.anns = anns
|
||||
self.imgToAnns = imgToAnns
|
||||
self.catToImgs = catToImgs
|
||||
self.imgs = imgs
|
||||
self.cats = cats
|
||||
|
||||
def info(self):
|
||||
"""
|
||||
Print information about the annotation file.
|
||||
:return:
|
||||
"""
|
||||
for key, value in self.dataset['info'].items():
|
||||
print('{}: {}'.format(key, value))
|
||||
|
||||
def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
|
||||
"""
|
||||
Get ann ids that satisfy given filter conditions. default skips that filter
|
||||
:param imgIds (int array) : get anns for given imgs
|
||||
catIds (int array) : get anns for given cats
|
||||
areaRng (float array) : get anns for given area range (e.g. [0 inf])
|
||||
iscrowd (boolean) : get anns for given crowd label (False or True)
|
||||
:return: ids (int array) : integer array of ann ids
|
||||
"""
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == len(areaRng) == 0:
|
||||
anns = self.dataset['annotations']
|
||||
else:
|
||||
if not len(imgIds) == 0:
|
||||
lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
|
||||
anns = list(itertools.chain.from_iterable(lists))
|
||||
else:
|
||||
anns = self.dataset['annotations']
|
||||
anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
|
||||
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
|
||||
if not iscrowd == None:
|
||||
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
|
||||
else:
|
||||
ids = [ann['id'] for ann in anns]
|
||||
return ids
|
||||
|
||||
def getCatIds(self, catNms=[], supNms=[], catIds=[]):
|
||||
"""
|
||||
filtering parameters. default skips that filter.
|
||||
:param catNms (str array) : get cats for given cat names
|
||||
:param supNms (str array) : get cats for given supercategory names
|
||||
:param catIds (int array) : get cats for given cat ids
|
||||
:return: ids (int array) : integer array of cat ids
|
||||
"""
|
||||
catNms = catNms if _isArrayLike(catNms) else [catNms]
|
||||
supNms = supNms if _isArrayLike(supNms) else [supNms]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(catNms) == len(supNms) == len(catIds) == 0:
|
||||
cats = self.dataset['categories']
|
||||
else:
|
||||
cats = self.dataset['categories']
|
||||
cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
|
||||
cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
|
||||
cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
|
||||
ids = [cat['id'] for cat in cats]
|
||||
return ids
|
||||
|
||||
def getImgIds(self, imgIds=[], catIds=[]):
|
||||
'''
|
||||
Get img ids that satisfy given filter conditions.
|
||||
:param imgIds (int array) : get imgs for given ids
|
||||
:param catIds (int array) : get imgs with all given cats
|
||||
:return: ids (int array) : integer array of img ids
|
||||
'''
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == 0:
|
||||
ids = self.imgs.keys()
|
||||
else:
|
||||
ids = set(imgIds)
|
||||
for i, catId in enumerate(catIds):
|
||||
if i == 0 and len(ids) == 0:
|
||||
ids = set(self.catToImgs[catId])
|
||||
else:
|
||||
ids &= set(self.catToImgs[catId])
|
||||
return list(ids)
|
||||
|
||||
def loadAnns(self, ids=[]):
|
||||
"""
|
||||
Load anns with the specified ids.
|
||||
:param ids (int array) : integer ids specifying anns
|
||||
:return: anns (object array) : loaded ann objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.anns[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.anns[ids]]
|
||||
|
||||
def loadCats(self, ids=[]):
|
||||
"""
|
||||
Load cats with the specified ids.
|
||||
:param ids (int array) : integer ids specifying cats
|
||||
:return: cats (object array) : loaded cat objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.cats[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.cats[ids]]
|
||||
|
||||
def loadImgs(self, ids=[]):
|
||||
"""
|
||||
Load anns with the specified ids.
|
||||
:param ids (int array) : integer ids specifying img
|
||||
:return: imgs (object array) : loaded img objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.imgs[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.imgs[ids]]
|
||||
|
||||
def showAnns(self, anns, draw_bbox=False):
|
||||
"""
|
||||
Display the specified annotations.
|
||||
:param anns (array of object): annotations to display
|
||||
:return: None
|
||||
"""
|
||||
if len(anns) == 0:
|
||||
return 0
|
||||
if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
|
||||
datasetType = 'instances'
|
||||
elif 'caption' in anns[0]:
|
||||
datasetType = 'captions'
|
||||
else:
|
||||
raise Exception('datasetType not supported')
|
||||
if datasetType == 'instances':
|
||||
ax = plt.gca()
|
||||
ax.set_autoscale_on(False)
|
||||
polygons = []
|
||||
color = []
|
||||
for ann in anns:
|
||||
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
|
||||
if 'segmentation' in ann:
|
||||
if type(ann['segmentation']) == list:
|
||||
# polygon
|
||||
for seg in ann['segmentation']:
|
||||
poly = np.array(seg).reshape((int(len(seg)/2), 2))
|
||||
polygons.append(Polygon(poly))
|
||||
color.append(c)
|
||||
else:
|
||||
# mask
|
||||
t = self.imgs[ann['image_id']]
|
||||
if type(ann['segmentation']['counts']) == list:
|
||||
rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
|
||||
else:
|
||||
rle = [ann['segmentation']]
|
||||
m = maskUtils.decode(rle)
|
||||
img = np.ones( (m.shape[0], m.shape[1], 3) )
|
||||
if ann['iscrowd'] == 1:
|
||||
color_mask = np.array([2.0,166.0,101.0])/255
|
||||
if ann['iscrowd'] == 0:
|
||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
for i in range(3):
|
||||
img[:,:,i] = color_mask[i]
|
||||
ax.imshow(np.dstack( (img, m*0.5) ))
|
||||
if 'keypoints' in ann and type(ann['keypoints']) == list:
|
||||
# turn skeleton into zero-based index
|
||||
sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
|
||||
kp = np.array(ann['keypoints'])
|
||||
x = kp[0::3]
|
||||
y = kp[1::3]
|
||||
v = kp[2::3]
|
||||
for sk in sks:
|
||||
if np.all(v[sk]>0):
|
||||
plt.plot(x[sk],y[sk], linewidth=3, color=c)
|
||||
plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
|
||||
plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
|
||||
|
||||
if draw_bbox:
|
||||
[bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
|
||||
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
|
||||
np_poly = np.array(poly).reshape((4,2))
|
||||
polygons.append(Polygon(np_poly))
|
||||
color.append(c)
|
||||
|
||||
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
|
||||
ax.add_collection(p)
|
||||
p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
|
||||
ax.add_collection(p)
|
||||
elif datasetType == 'captions':
|
||||
for ann in anns:
|
||||
print(ann['caption'])
|
||||
|
||||
def loadRes(self, resFile):
|
||||
"""
|
||||
Load result file and return a result api object.
|
||||
:param resFile (str) : file name of result file
|
||||
:return: res (obj) : result api object
|
||||
"""
|
||||
res = COCO()
|
||||
res.dataset['images'] = [img for img in self.dataset['images']]
|
||||
|
||||
# print('Loading and preparing results...')
|
||||
tic = time.time()
|
||||
if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode):
|
||||
anns = json.load(open(resFile))
|
||||
elif type(resFile) == np.ndarray:
|
||||
anns = self.loadNumpyAnnotations(resFile)
|
||||
else:
|
||||
anns = resFile
|
||||
assert type(anns) == list, 'results in not an array of objects'
|
||||
annsImgIds = [ann['image_id'] for ann in anns]
|
||||
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
|
||||
'Results do not correspond to current coco set'
|
||||
if 'caption' in anns[0]:
|
||||
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
|
||||
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
|
||||
for id, ann in enumerate(anns):
|
||||
ann['id'] = id+1
|
||||
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
bb = ann['bbox']
|
||||
x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
|
||||
if not 'segmentation' in ann:
|
||||
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
|
||||
ann['area'] = bb[2]*bb[3]
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'segmentation' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
# now only support compressed RLE format as segmentation results
|
||||
ann['area'] = maskUtils.area(ann['segmentation'])
|
||||
if not 'bbox' in ann:
|
||||
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'keypoints' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
s = ann['keypoints']
|
||||
x = s[0::3]
|
||||
y = s[1::3]
|
||||
x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
|
||||
ann['area'] = (x1-x0)*(y1-y0)
|
||||
ann['id'] = id + 1
|
||||
ann['bbox'] = [x0,y0,x1-x0,y1-y0]
|
||||
# print('DONE (t={:0.2f}s)'.format(time.time()- tic))
|
||||
|
||||
res.dataset['annotations'] = anns
|
||||
res.createIndex()
|
||||
return res
|
||||
|
||||
def download(self, tarDir = None, imgIds = [] ):
|
||||
'''
|
||||
Download COCO images from mscoco.org server.
|
||||
:param tarDir (str): COCO results directory name
|
||||
imgIds (list): images to be downloaded
|
||||
:return:
|
||||
'''
|
||||
if tarDir is None:
|
||||
print('Please specify target directory')
|
||||
return -1
|
||||
if len(imgIds) == 0:
|
||||
imgs = self.imgs.values()
|
||||
else:
|
||||
imgs = self.loadImgs(imgIds)
|
||||
N = len(imgs)
|
||||
if not os.path.exists(tarDir):
|
||||
os.makedirs(tarDir)
|
||||
for i, img in enumerate(imgs):
|
||||
tic = time.time()
|
||||
fname = os.path.join(tarDir, img['file_name'])
|
||||
if not os.path.exists(fname):
|
||||
urlretrieve(img['coco_url'], fname)
|
||||
print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic))
|
||||
|
||||
def loadNumpyAnnotations(self, data):
|
||||
"""
|
||||
Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class}
|
||||
:param data (numpy.ndarray)
|
||||
:return: annotations (python nested list)
|
||||
"""
|
||||
print('Converting ndarray to lists...')
|
||||
assert(type(data) == np.ndarray)
|
||||
print(data.shape)
|
||||
assert(data.shape[1] == 7)
|
||||
N = data.shape[0]
|
||||
ann = []
|
||||
for i in range(N):
|
||||
if i % 1000000 == 0:
|
||||
print('{}/{}'.format(i,N))
|
||||
ann += [{
|
||||
'image_id' : int(data[i, 0]),
|
||||
'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ],
|
||||
'score' : data[i, 5],
|
||||
'category_id': int(data[i, 6]),
|
||||
}]
|
||||
return ann
|
||||
|
||||
def annToRLE(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE to RLE.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
t = self.imgs[ann['image_id']]
|
||||
h, w = t['height'], t['width']
|
||||
segm = ann['segmentation']
|
||||
if type(segm) == list:
|
||||
# polygon -- a single object might consist of multiple parts
|
||||
# we merge all parts into one mask rle code
|
||||
rles = maskUtils.frPyObjects(segm, h, w)
|
||||
rle = maskUtils.merge(rles)
|
||||
elif type(segm['counts']) == list:
|
||||
# uncompressed RLE
|
||||
rle = maskUtils.frPyObjects(segm, h, w)
|
||||
else:
|
||||
# rle
|
||||
rle = ann['segmentation']
|
||||
return rle
|
||||
|
||||
def annToMask(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
rle = self.annToRLE(ann)
|
||||
m = maskUtils.decode(rle)
|
||||
return m
|
||||
@@ -0,0 +1,532 @@
|
||||
import numpy as np
|
||||
import datetime
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pycocotools import mask as maskUtils
|
||||
import copy
|
||||
|
||||
class COCOeval:
|
||||
# Interface for evaluating detection on the Microsoft COCO dataset.
|
||||
#
|
||||
# The usage for CocoEval is as follows:
|
||||
# cocoGt=..., cocoDt=... # load dataset and results
|
||||
# E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
|
||||
# E.params.recThrs = ...; # set parameters as desired
|
||||
# E.evaluate(); # run per image evaluation
|
||||
# E.accumulate(); # accumulate per image results
|
||||
# E.summarize(); # display summary metrics of results
|
||||
# For example usage see evalDemo.m and http://mscoco.org/.
|
||||
#
|
||||
# The evaluation parameters are as follows (defaults in brackets):
|
||||
# imgIds - [all] N img ids to use for evaluation
|
||||
# catIds - [all] K cat ids to use for evaluation
|
||||
# iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
|
||||
# recThrs - [0:.01:1] R=101 recall thresholds for evaluation
|
||||
# areaRng - [...] A=4 object area ranges for evaluation
|
||||
# maxDets - [1 10 100] M=3 thresholds on max detections per image
|
||||
# iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints'
|
||||
# iouType replaced the now DEPRECATED useSegm parameter.
|
||||
# useCats - [1] if true use category labels for evaluation
|
||||
# Note: if useCats=0 category labels are ignored as in proposal scoring.
|
||||
# Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
|
||||
#
|
||||
# evaluate(): evaluates detections on every image and every category and
|
||||
# concats the results into the "evalImgs" with fields:
|
||||
# dtIds - [1xD] id for each of the D detections (dt)
|
||||
# gtIds - [1xG] id for each of the G ground truths (gt)
|
||||
# dtMatches - [TxD] matching gt id at each IoU or 0
|
||||
# gtMatches - [TxG] matching dt id at each IoU or 0
|
||||
# dtScores - [1xD] confidence of each dt
|
||||
# gtIgnore - [1xG] ignore flag for each gt
|
||||
# dtIgnore - [TxD] ignore flag for each dt at each IoU
|
||||
#
|
||||
# accumulate(): accumulates the per-image, per-category evaluation
|
||||
# results in "evalImgs" into the dictionary "eval" with fields:
|
||||
# params - parameters used for evaluation
|
||||
# date - date evaluation was performed
|
||||
# counts - [T,R,K,A,M] parameter dimensions (see above)
|
||||
# precision - [TxRxKxAxM] precision for every evaluation setting
|
||||
# recall - [TxKxAxM] max recall for every evaluation setting
|
||||
# Note: precision and recall==-1 for settings with no gt objects.
|
||||
#
|
||||
# See also coco, mask, pycocoDemo, pycocoEvalDemo
|
||||
#
|
||||
# Microsoft COCO Toolbox. version 2.0
|
||||
# Data, paper, and tutorials available at: http://mscoco.org/
|
||||
# Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
|
||||
# Licensed under the Simplified BSD License [see coco/license.txt]
|
||||
def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
|
||||
'''
|
||||
Initialize CocoEval using coco APIs for gt and dt
|
||||
:param cocoGt: coco object with ground truth annotations
|
||||
:param cocoDt: coco object with detection results
|
||||
:return: None
|
||||
'''
|
||||
if not iouType:
|
||||
print('iouType not specified. use default iouType segm')
|
||||
self.cocoGt = cocoGt # ground truth COCO API
|
||||
self.cocoDt = cocoDt # detections COCO API
|
||||
self.evalImgs = defaultdict(list) # per-image per-category evaluation results [KxAxI] elements
|
||||
self.eval = {} # accumulated evaluation results
|
||||
self._gts = defaultdict(list) # gt for evaluation
|
||||
self._dts = defaultdict(list) # dt for evaluation
|
||||
self.params = Params(iouType=iouType) # parameters
|
||||
self._paramsEval = {} # parameters for evaluation
|
||||
self.stats = [] # result summarization
|
||||
self.ious = {} # ious between all gts and dts
|
||||
if not cocoGt is None:
|
||||
self.params.imgIds = sorted(cocoGt.getImgIds())
|
||||
self.params.catIds = sorted(cocoGt.getCatIds())
|
||||
|
||||
|
||||
def _prepare(self):
|
||||
'''
|
||||
Prepare ._gts and ._dts for evaluation based on params
|
||||
:return: None
|
||||
'''
|
||||
def _toMask(anns, coco):
|
||||
# modify ann['segmentation'] by reference
|
||||
for ann in anns:
|
||||
rle = coco.annToRLE(ann)
|
||||
ann['segmentation'] = rle
|
||||
p = self.params
|
||||
if p.useCats:
|
||||
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
|
||||
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
|
||||
else:
|
||||
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
|
||||
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
|
||||
|
||||
# convert ground truth to mask if iouType == 'segm'
|
||||
if p.iouType == 'segm':
|
||||
_toMask(gts, self.cocoGt)
|
||||
_toMask(dts, self.cocoDt)
|
||||
# set ignore flag
|
||||
for gt in gts:
|
||||
gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0
|
||||
gt['ignore'] = 'iscrowd' in gt and gt['iscrowd']
|
||||
if p.iouType == 'keypoints':
|
||||
gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
|
||||
self._gts = defaultdict(list) # gt for evaluation
|
||||
self._dts = defaultdict(list) # dt for evaluation
|
||||
for gt in gts:
|
||||
self._gts[gt['image_id'], gt['category_id']].append(gt)
|
||||
for dt in dts:
|
||||
self._dts[dt['image_id'], dt['category_id']].append(dt)
|
||||
self.evalImgs = defaultdict(list) # per-image per-category evaluation results
|
||||
self.eval = {} # accumulated evaluation results
|
||||
|
||||
def evaluate(self):
|
||||
'''
|
||||
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
|
||||
:return: None
|
||||
'''
|
||||
tic = time.time()
|
||||
#('Running per image evaluation...')
|
||||
p = self.params
|
||||
# add backward compatibility if useSegm is specified in params
|
||||
if not p.useSegm is None:
|
||||
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
|
||||
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
|
||||
# print('Evaluate annotation type *{}*'.format(p.iouType))
|
||||
p.imgIds = list(np.unique(p.imgIds))
|
||||
if p.useCats:
|
||||
p.catIds = list(np.unique(p.catIds))
|
||||
p.maxDets = sorted(p.maxDets)
|
||||
self.params=p
|
||||
|
||||
self._prepare()
|
||||
# loop through images, area range, max detection number
|
||||
catIds = p.catIds if p.useCats else [-1]
|
||||
|
||||
if p.iouType == 'segm' or p.iouType == 'bbox':
|
||||
computeIoU = self.computeIoU
|
||||
elif p.iouType == 'keypoints':
|
||||
computeIoU = self.computeOks
|
||||
self.ious = {(imgId, catId): computeIoU(imgId, catId) \
|
||||
for imgId in p.imgIds
|
||||
for catId in catIds}
|
||||
|
||||
evaluateImg = self.evaluateImg
|
||||
maxDet = p.maxDets[-1]
|
||||
self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet)
|
||||
for catId in catIds
|
||||
for areaRng in p.areaRng
|
||||
for imgId in p.imgIds
|
||||
]
|
||||
self._paramsEval = copy.deepcopy(self.params)
|
||||
toc = time.time()
|
||||
#print('DONE (t={:0.2f}s).'.format(toc-tic))
|
||||
|
||||
def computeIoU(self, imgId, catId):
|
||||
p = self.params
|
||||
if p.useCats:
|
||||
gt = self._gts[imgId,catId]
|
||||
dt = self._dts[imgId,catId]
|
||||
else:
|
||||
gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
|
||||
dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
|
||||
if len(gt) == 0 and len(dt) ==0:
|
||||
return []
|
||||
inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
|
||||
dt = [dt[i] for i in inds]
|
||||
if len(dt) > p.maxDets[-1]:
|
||||
dt=dt[0:p.maxDets[-1]]
|
||||
|
||||
if p.iouType == 'segm':
|
||||
g = [g['segmentation'] for g in gt]
|
||||
d = [d['segmentation'] for d in dt]
|
||||
elif p.iouType == 'bbox':
|
||||
g = [g['bbox'] for g in gt]
|
||||
d = [d['bbox'] for d in dt]
|
||||
else:
|
||||
raise Exception('unknown iouType for iou computation')
|
||||
|
||||
# compute iou between each dt and gt region
|
||||
iscrowd = [int(o['iscrowd']) for o in gt]
|
||||
ious = maskUtils.iou(d,g,iscrowd)
|
||||
return ious
|
||||
|
||||
def computeOks(self, imgId, catId):
|
||||
p = self.params
|
||||
# dimention here should be Nxm
|
||||
gts = self._gts[imgId, catId]
|
||||
dts = self._dts[imgId, catId]
|
||||
inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
|
||||
dts = [dts[i] for i in inds]
|
||||
if len(dts) > p.maxDets[-1]:
|
||||
dts = dts[0:p.maxDets[-1]]
|
||||
# if len(gts) == 0 and len(dts) == 0:
|
||||
if len(gts) == 0 or len(dts) == 0:
|
||||
return []
|
||||
ious = np.zeros((len(dts), len(gts)))
|
||||
sigmas = p.kpt_oks_sigmas
|
||||
vars = (sigmas * 2)**2
|
||||
k = len(sigmas)
|
||||
# compute oks between each detection and ground truth object
|
||||
for j, gt in enumerate(gts):
|
||||
# create bounds for ignore regions(double the gt bbox)
|
||||
g = np.array(gt['keypoints'])
|
||||
xg = g[0::3]; yg = g[1::3]; vg = g[2::3]
|
||||
k1 = np.count_nonzero(vg > 0)
|
||||
bb = gt['bbox']
|
||||
x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
|
||||
y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
|
||||
for i, dt in enumerate(dts):
|
||||
d = np.array(dt['keypoints'])
|
||||
xd = d[0::3]; yd = d[1::3]
|
||||
if k1>0:
|
||||
# measure the per-keypoint distance if keypoints visible
|
||||
dx = xd - xg
|
||||
dy = yd - yg
|
||||
else:
|
||||
# measure minimum distance to keypoints in (x0,y0) & (x1,y1)
|
||||
z = np.zeros((k))
|
||||
dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0)
|
||||
dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0)
|
||||
e = (dx**2 + dy**2) / vars / (gt['area']+np.spacing(1)) / 2
|
||||
if k1 > 0:
|
||||
e=e[vg > 0]
|
||||
ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
|
||||
return ious
|
||||
|
||||
def evaluateImg(self, imgId, catId, aRng, maxDet):
|
||||
'''
|
||||
perform evaluation for single category and image
|
||||
:return: dict (single image results)
|
||||
'''
|
||||
p = self.params
|
||||
if p.useCats:
|
||||
gt = self._gts[imgId,catId]
|
||||
dt = self._dts[imgId,catId]
|
||||
else:
|
||||
gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
|
||||
dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
|
||||
if len(gt) == 0 and len(dt) ==0:
|
||||
return None
|
||||
|
||||
for g in gt:
|
||||
if g['ignore'] or (g['area']<aRng[0] or g['area']>aRng[1]):
|
||||
g['_ignore'] = 1
|
||||
else:
|
||||
g['_ignore'] = 0
|
||||
|
||||
# sort dt highest score first, sort gt ignore last
|
||||
gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
|
||||
gt = [gt[i] for i in gtind]
|
||||
dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
|
||||
dt = [dt[i] for i in dtind[0:maxDet]]
|
||||
iscrowd = [int(o['iscrowd']) for o in gt]
|
||||
# load computed ious
|
||||
ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId]
|
||||
|
||||
T = len(p.iouThrs)
|
||||
G = len(gt)
|
||||
D = len(dt)
|
||||
gtm = np.zeros((T,G))
|
||||
dtm = np.zeros((T,D))
|
||||
gtIg = np.array([g['_ignore'] for g in gt])
|
||||
dtIg = np.zeros((T,D))
|
||||
if not len(ious)==0:
|
||||
for tind, t in enumerate(p.iouThrs):
|
||||
for dind, d in enumerate(dt):
|
||||
# information about best match so far (m=-1 -> unmatched)
|
||||
iou = min([t,1-1e-10])
|
||||
m = -1
|
||||
for gind, g in enumerate(gt):
|
||||
# if this gt already matched, and not a crowd, continue
|
||||
if gtm[tind,gind]>0 and not iscrowd[gind]:
|
||||
continue
|
||||
# if dt matched to reg gt, and on ignore gt, stop
|
||||
if m>-1 and gtIg[m]==0 and gtIg[gind]==1:
|
||||
break
|
||||
# continue to next gt unless better match made
|
||||
if ious[dind,gind] < iou:
|
||||
continue
|
||||
# if match successful and best so far, store appropriately
|
||||
iou=ious[dind,gind]
|
||||
m=gind
|
||||
# if match made store id of match for both dt and gt
|
||||
if m ==-1:
|
||||
continue
|
||||
dtIg[tind,dind] = gtIg[m]
|
||||
dtm[tind,dind] = gt[m]['id']
|
||||
gtm[tind,m] = d['id']
|
||||
# set unmatched detections outside of area range to ignore
|
||||
a = np.array([d['area']<aRng[0] or d['area']>aRng[1] for d in dt]).reshape((1, len(dt)))
|
||||
dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0)))
|
||||
# store results for given image and category
|
||||
return {
|
||||
'image_id': imgId,
|
||||
'category_id': catId,
|
||||
'aRng': aRng,
|
||||
'maxDet': maxDet,
|
||||
'dtIds': [d['id'] for d in dt],
|
||||
'gtIds': [g['id'] for g in gt],
|
||||
'dtMatches': dtm,
|
||||
'gtMatches': gtm,
|
||||
'dtScores': [d['score'] for d in dt],
|
||||
'gtIgnore': gtIg,
|
||||
'dtIgnore': dtIg,
|
||||
}
|
||||
|
||||
def accumulate(self, p = None):
|
||||
'''
|
||||
Accumulate per image evaluation results and store the result in self.eval
|
||||
:param p: input params for evaluation
|
||||
:return: None
|
||||
'''
|
||||
#print('Accumulating evaluation results...')
|
||||
tic = time.time()
|
||||
if not self.evalImgs:
|
||||
print('Please run evaluate() first')
|
||||
# allows input customized parameters
|
||||
if p is None:
|
||||
p = self.params
|
||||
p.catIds = p.catIds if p.useCats == 1 else [-1]
|
||||
T = len(p.iouThrs)
|
||||
R = len(p.recThrs)
|
||||
K = len(p.catIds) if p.useCats else 1
|
||||
A = len(p.areaRng)
|
||||
M = len(p.maxDets)
|
||||
precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories
|
||||
recall = -np.ones((T,K,A,M))
|
||||
scores = -np.ones((T,R,K,A,M))
|
||||
|
||||
# create dictionary for future indexing
|
||||
_pe = self._paramsEval
|
||||
catIds = _pe.catIds if _pe.useCats else [-1]
|
||||
setK = set(catIds)
|
||||
setA = set(map(tuple, _pe.areaRng))
|
||||
setM = set(_pe.maxDets)
|
||||
setI = set(_pe.imgIds)
|
||||
# get inds to evaluate
|
||||
k_list = [n for n, k in enumerate(p.catIds) if k in setK]
|
||||
m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
|
||||
a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
|
||||
i_list = [n for n, i in enumerate(p.imgIds) if i in setI]
|
||||
I0 = len(_pe.imgIds)
|
||||
A0 = len(_pe.areaRng)
|
||||
# retrieve E at each category, area range, and max number of detections
|
||||
for k, k0 in enumerate(k_list):
|
||||
Nk = k0*A0*I0
|
||||
for a, a0 in enumerate(a_list):
|
||||
Na = a0*I0
|
||||
for m, maxDet in enumerate(m_list):
|
||||
E = [self.evalImgs[Nk + Na + i] for i in i_list]
|
||||
E = [e for e in E if not e is None]
|
||||
if len(E) == 0:
|
||||
continue
|
||||
dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E])
|
||||
|
||||
# different sorting method generates slightly different results.
|
||||
# mergesort is used to be consistent as Matlab implementation.
|
||||
inds = np.argsort(-dtScores, kind='mergesort')
|
||||
dtScoresSorted = dtScores[inds]
|
||||
|
||||
dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds]
|
||||
dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds]
|
||||
gtIg = np.concatenate([e['gtIgnore'] for e in E])
|
||||
npig = np.count_nonzero(gtIg==0 )
|
||||
if npig == 0:
|
||||
continue
|
||||
tps = np.logical_and( dtm, np.logical_not(dtIg) )
|
||||
fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) )
|
||||
|
||||
tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
|
||||
fp_sum = np.cumsum(fps, axis=1).astype(dtype=float)
|
||||
for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
|
||||
tp = np.array(tp)
|
||||
fp = np.array(fp)
|
||||
nd = len(tp)
|
||||
rc = tp / npig
|
||||
pr = tp / (fp+tp+np.spacing(1))
|
||||
q = np.zeros((R,))
|
||||
ss = np.zeros((R,))
|
||||
|
||||
if nd:
|
||||
recall[t,k,a,m] = rc[-1]
|
||||
else:
|
||||
recall[t,k,a,m] = 0
|
||||
|
||||
# numpy is slow without cython optimization for accessing elements
|
||||
# use python array gets significant speed improvement
|
||||
pr = pr.tolist(); q = q.tolist()
|
||||
|
||||
for i in range(nd-1, 0, -1):
|
||||
if pr[i] > pr[i-1]:
|
||||
pr[i-1] = pr[i]
|
||||
|
||||
inds = np.searchsorted(rc, p.recThrs, side='left')
|
||||
try:
|
||||
for ri, pi in enumerate(inds):
|
||||
q[ri] = pr[pi]
|
||||
ss[ri] = dtScoresSorted[pi]
|
||||
except:
|
||||
pass
|
||||
precision[t,:,k,a,m] = np.array(q)
|
||||
scores[t,:,k,a,m] = np.array(ss)
|
||||
self.eval = {
|
||||
'params': p,
|
||||
'counts': [T, R, K, A, M],
|
||||
'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'precision': precision,
|
||||
'recall': recall,
|
||||
'scores': scores,
|
||||
}
|
||||
toc = time.time()
|
||||
# print('DONE (t={:0.2f}s).'.format( toc-tic))
|
||||
|
||||
def summarize(self):
|
||||
'''
|
||||
Compute and display summary metrics for evaluation results.
|
||||
Note this functin can *only* be applied on the default parameter setting
|
||||
'''
|
||||
def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
|
||||
p = self.params
|
||||
iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
|
||||
titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
|
||||
typeStr = '(AP)' if ap==1 else '(AR)'
|
||||
iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
|
||||
if iouThr is None else '{:0.2f}'.format(iouThr)
|
||||
|
||||
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
|
||||
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
|
||||
if ap == 1:
|
||||
# dimension of precision: [TxRxKxAxM]
|
||||
s = self.eval['precision']
|
||||
# IoU
|
||||
if iouThr is not None:
|
||||
t = np.where(iouThr == p.iouThrs)[0]
|
||||
s = s[t]
|
||||
s = s[:,:,:,aind,mind]
|
||||
else:
|
||||
# dimension of recall: [TxKxAxM]
|
||||
s = self.eval['recall']
|
||||
if iouThr is not None:
|
||||
t = np.where(iouThr == p.iouThrs)[0]
|
||||
s = s[t]
|
||||
s = s[:,:,aind,mind]
|
||||
if len(s[s>-1])==0:
|
||||
mean_s = -1
|
||||
else:
|
||||
mean_s = np.mean(s[s>-1])
|
||||
#print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
|
||||
return mean_s
|
||||
def _summarizeDets():
|
||||
stats = np.zeros((12,))
|
||||
stats[0] = _summarize(1)
|
||||
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
|
||||
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
|
||||
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
|
||||
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
|
||||
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
|
||||
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
|
||||
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
|
||||
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
|
||||
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
|
||||
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
|
||||
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
|
||||
return stats
|
||||
def _summarizeKps():
|
||||
stats = np.zeros((10,))
|
||||
stats[0] = _summarize(1, maxDets=20)
|
||||
stats[1] = _summarize(1, maxDets=20, iouThr=.5)
|
||||
stats[2] = _summarize(1, maxDets=20, iouThr=.75)
|
||||
stats[3] = _summarize(1, maxDets=20, areaRng='medium')
|
||||
stats[4] = _summarize(1, maxDets=20, areaRng='large')
|
||||
stats[5] = _summarize(0, maxDets=20)
|
||||
stats[6] = _summarize(0, maxDets=20, iouThr=.5)
|
||||
stats[7] = _summarize(0, maxDets=20, iouThr=.75)
|
||||
stats[8] = _summarize(0, maxDets=20, areaRng='medium')
|
||||
stats[9] = _summarize(0, maxDets=20, areaRng='large')
|
||||
return stats
|
||||
if not self.eval:
|
||||
raise Exception('Please run accumulate() first')
|
||||
iouType = self.params.iouType
|
||||
if iouType == 'segm' or iouType == 'bbox':
|
||||
summarize = _summarizeDets
|
||||
elif iouType == 'keypoints':
|
||||
summarize = _summarizeKps
|
||||
self.stats = summarize()
|
||||
|
||||
def __str__(self):
|
||||
self.summarize()
|
||||
|
||||
class Params:
|
||||
'''
|
||||
Params for coco evaluation api
|
||||
'''
|
||||
def setDetParams(self):
|
||||
self.imgIds = []
|
||||
self.catIds = []
|
||||
# np.arange causes trouble. the data point on arange is slightly larger than the true value
|
||||
self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
|
||||
self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
|
||||
self.maxDets = [1, 10, 100]
|
||||
self.areaRng = [[0 ** 2, 1e5 ** 2], [0 ** 2, 32 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
|
||||
self.areaRngLbl = ['all', 'small', 'medium', 'large']
|
||||
self.useCats = 1
|
||||
|
||||
def setKpParams(self):
|
||||
self.imgIds = []
|
||||
self.catIds = []
|
||||
# np.arange causes trouble. the data point on arange is slightly larger than the true value
|
||||
self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
|
||||
self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
|
||||
self.maxDets = [20]
|
||||
self.areaRng = [[0 ** 2, 1e5 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
|
||||
self.areaRngLbl = ['all', 'medium', 'large']
|
||||
self.useCats = 1
|
||||
self.kpt_oks_sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0
|
||||
|
||||
def __init__(self, iouType='segm'):
|
||||
if iouType == 'segm' or iouType == 'bbox':
|
||||
self.setDetParams()
|
||||
elif iouType == 'keypoints':
|
||||
self.setKpParams()
|
||||
else:
|
||||
raise Exception('iouType not supported')
|
||||
self.iouType = iouType
|
||||
# useSegm is deprecated
|
||||
self.useSegm = None
|
||||
@@ -0,0 +1,5 @@
|
||||
from .vlm_module import VLMBaseModule
|
||||
from .qwen_module import Qwen2VLModule
|
||||
from .internvl_module import InvernVLModule
|
||||
|
||||
__all__ = ["VLMBaseModule", "Qwen2VLModule", "InvernVLModule"]
|
||||
@@ -0,0 +1,328 @@
|
||||
from open_r1.vlm_modules.vlm_module import VLMBaseModule
|
||||
from typing import Dict, Any, Union
|
||||
from transformers import AutoModel, AutoProcessor, AutoConfig
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from transformers.feature_extraction_sequence_utils import BatchFeature
|
||||
|
||||
IMG_START_TOKEN='<img>'
|
||||
IMG_END_TOKEN='</img>'
|
||||
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'
|
||||
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
class InvernVLModule(VLMBaseModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv_template = None
|
||||
self.num_image_token = None
|
||||
|
||||
def get_vlm_key(self):
|
||||
return "internvl"
|
||||
|
||||
def get_model_class(self, model_id: str, model_init_kwargs: dict):
|
||||
assert "InternVL" in model_id, f"model_id must contain 'InternVL', but got {model_id}"
|
||||
self.model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
|
||||
# The model class of InternVL when being mapped has been determined by its config
|
||||
model_cls = AutoModel
|
||||
# InternVL should be inputted with "trust_remote_code=True"
|
||||
model_init_kwargs["trust_remote_code"] = True
|
||||
# "use_cache" should be removed
|
||||
model_init_kwargs.pop("use_cache", None)
|
||||
# "flash_attention_2" should be modified to "use_flash_attn" in InternVL
|
||||
if "flash_attention_2" in model_init_kwargs.get("attn_implementation", ""):
|
||||
model_init_kwargs["use_flash_attn"] = True
|
||||
model_init_kwargs.pop("attn_implementation")
|
||||
return model_cls
|
||||
|
||||
def post_model_init(self, model, processing_class):
|
||||
self.conv_template = model.conv_template if self.conv_template is None else self.conv_template
|
||||
self.num_image_token = model.num_image_token if self.num_image_token is None else self.num_image_token
|
||||
img_context_token_id = processing_class.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
||||
model.img_context_token_id = img_context_token_id
|
||||
|
||||
def is_embeds_input(self):
|
||||
return True
|
||||
|
||||
def get_processing_class(self):
|
||||
return AutoProcessor
|
||||
|
||||
def get_eos_token_id(self, processing_class):
|
||||
eos_token_id = processing_class.convert_tokens_to_ids(self.conv_template.sep.strip())
|
||||
return eos_token_id
|
||||
|
||||
def get_vision_modules_keywords(self):
|
||||
return ['vision_model']
|
||||
|
||||
def get_custom_multimodal_keywords(self):
|
||||
return ['pixel_values', 'image_flags']
|
||||
|
||||
def get_non_generate_params(self):
|
||||
return ['image_flags']
|
||||
|
||||
def get_custom_processing_keywords(self):
|
||||
return [('None', 'max_anyres_num')]
|
||||
|
||||
def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
|
||||
prompts_text = []
|
||||
for example in inputs:
|
||||
template = self.conv_template.copy()
|
||||
conversation_list = example["prompt"]
|
||||
system_message = extract_system_message(conversation_list)
|
||||
if system_message is not None:
|
||||
template.system_message = system_message
|
||||
|
||||
processed_list = process_conversation_list(conversation_list, system_message)
|
||||
for i, processed_item in enumerate(processed_list):
|
||||
if i % 2 == 0:
|
||||
template.append_message(template.roles[0], processed_item)
|
||||
else:
|
||||
template.append_message(template.roles[1], processed_item)
|
||||
if len(processed_list) % 2 == 1:
|
||||
template.append_message(template.roles[1], None)
|
||||
query = template.get_prompt()
|
||||
prompts_text.append(query)
|
||||
return prompts_text
|
||||
|
||||
def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
|
||||
# Process images
|
||||
full_pixel_values = []
|
||||
num_patches_list = []
|
||||
for img in images:
|
||||
pixel_values = self._load_image(img, input_size=self.model_config.vision_config.image_size, max_num=processing_class.max_anyres_num)
|
||||
full_pixel_values.append(pixel_values)
|
||||
num_patches_list.append(pixel_values.shape[0])
|
||||
full_pixel_values = torch.cat(full_pixel_values, dim=0)
|
||||
|
||||
# Process prompts
|
||||
queries = []
|
||||
image_idx = 0
|
||||
for query in prompts_text:
|
||||
while "<image>" in query:
|
||||
num_patches = num_patches_list[image_idx]
|
||||
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
|
||||
query = query.replace("<image>", image_tokens, 1)
|
||||
image_idx += 1
|
||||
queries.append(query)
|
||||
assert image_idx == len(num_patches_list)
|
||||
|
||||
model_inputs = processing_class(
|
||||
queries,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
padding_side=padding_side,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
model_inputs["pixel_values"] = full_pixel_values
|
||||
# Only support pure-image data currently (each sample should contain the image)
|
||||
model_inputs['image_flags'] = torch.ones(full_pixel_values.shape[0], dtype=torch.long)
|
||||
|
||||
model_inputs = BatchFeature(data=model_inputs)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _load_image(self, image: Image.Image, input_size: int=448, max_num:int=12):
|
||||
transform = build_transform(input_size=input_size)
|
||||
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
return pixel_values
|
||||
|
||||
@staticmethod
|
||||
def get_question_template(task_type: str):
|
||||
match task_type:
|
||||
case _:
|
||||
return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
|
||||
|
||||
@staticmethod
|
||||
def format_reward_rec(completions, **kwargs):
|
||||
"""Check if the InternVL model output matches a specific format."""
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
pattern = r"<think>.*?</think>\s*<answer>.*?\[\d+,\s*\d+,\s*\d+,\s*\d+\].*?</answer>"
|
||||
completion_contents = [completion[0]["content"] for completion in completions]
|
||||
matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
|
||||
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
||||
if os.getenv("DEBUG_MODE") == "true":
|
||||
log_path = os.getenv("LOG_PATH")
|
||||
with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
|
||||
f.write(f"------------- {current_time} Format reward -------------\n")
|
||||
for content, match in zip(completion_contents, matches):
|
||||
f.write(f"Content: {content}\n")
|
||||
f.write(f"Has format: {bool(match)}\n")
|
||||
return [1.0 if match else 0.0 for match in matches]
|
||||
|
||||
@staticmethod
|
||||
def iou_reward(completions, solution, **kwargs):
|
||||
"""Calculate IoU reward between predicted bounding box from InternVL model and ground truth bounding box."""
|
||||
"""Adopt soft iou reward here"""
|
||||
import re
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
def iou(box1, box2):
|
||||
inter_x1 = max(box1[0], box2[0])
|
||||
inter_y1 = max(box1[1], box2[1])
|
||||
inter_x2 = min(box1[2]-1, box2[2]-1)
|
||||
inter_y2 = min(box1[3]-1, box2[3]-1)
|
||||
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
||||
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
|
||||
else:
|
||||
inter = 0
|
||||
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
|
||||
return float(inter)/union
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
rewards = []
|
||||
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
||||
answer_tag_pattern = r'<answer>(.*?)</answer>'
|
||||
bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
|
||||
for content, sol in zip(contents, solution):
|
||||
sol = re.findall(answer_tag_pattern, sol, re.DOTALL)[-1]
|
||||
sol = json.loads(sol.strip())
|
||||
reward = 0.0
|
||||
# Try symbolic verification first
|
||||
try:
|
||||
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
||||
if content_answer_match:
|
||||
content_answer = content_answer_match.group(1).strip()
|
||||
bbox_match = re.search(bbox_pattern, content_answer)
|
||||
if bbox_match:
|
||||
bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
|
||||
reward = iou(bbox, sol)
|
||||
except Exception:
|
||||
pass # Continue to next verification method if this fails
|
||||
|
||||
rewards.append(reward)
|
||||
if os.getenv("DEBUG_MODE") == "true":
|
||||
log_path = os.getenv("LOG_PATH")
|
||||
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
||||
image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
|
||||
problem = kwargs.get("problem")[0]
|
||||
if reward <= 1.0: # this condition can be changed for debug
|
||||
with open(log_path, "a", encoding='utf-8') as f:
|
||||
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
|
||||
f.write(f"image_path: {image_path}\n")
|
||||
f.write(f"problem: {problem}\n")
|
||||
f.write(f"Content: {content}\n")
|
||||
f.write(f"Solution: {sol}\n")
|
||||
return rewards
|
||||
|
||||
@staticmethod
|
||||
def select_reward_func(func: str, task_type: str):
|
||||
if func == "accuracy":
|
||||
match task_type:
|
||||
case "rec":
|
||||
return InvernVLModule.iou_reward
|
||||
case _:
|
||||
raise ValueError(f"Unsupported reward function: {func}")
|
||||
elif func == "format":
|
||||
match task_type:
|
||||
case "rec":
|
||||
return InvernVLModule.format_reward_rec
|
||||
case _:
|
||||
raise ValueError(f"Unsupported reward function: {func}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported reward function: {func}")
|
||||
|
||||
|
||||
def process_conversation_list(conversation_list, system_message=None, image_newline=True):
|
||||
if system_message is not None:
|
||||
conversation_list = conversation_list[1:]
|
||||
processed_list = []
|
||||
|
||||
for item in conversation_list:
|
||||
role = item["role"]
|
||||
content = item["content"]
|
||||
|
||||
if isinstance(content, list):
|
||||
overall_str = ""
|
||||
for content_item in content:
|
||||
if content_item.get("type") == "image":
|
||||
overall_str += "<image>" if not image_newline else "<image>\n"
|
||||
elif content_item.get("type") == "text":
|
||||
overall_str += content_item.get("text")
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content_item)}")
|
||||
processed_list.append(overall_str)
|
||||
elif isinstance(content, str):
|
||||
processed_list.append(content)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||
|
||||
return processed_list
|
||||
|
||||
def extract_system_message(conversation_list):
|
||||
if conversation_list[0]["role"] == "system":
|
||||
if isinstance(conversation_list[0]["content"], list):
|
||||
return conversation_list[0]["content"][0]["text"]
|
||||
else:
|
||||
return conversation_list[0]["content"]
|
||||
return None
|
||||
|
||||
|
||||
def build_transform(input_size):
|
||||
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
||||
transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=MEAN, std=STD)
|
||||
])
|
||||
return transform
|
||||
|
||||
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
||||
best_ratio_diff = float('inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
|
||||
orig_width, orig_height = image.size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set(
|
||||
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
||||
i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size
|
||||
)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
return processed_images
|
||||
@@ -0,0 +1,175 @@
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor
|
||||
from typing import Dict, Any, Union
|
||||
from trl.data_utils import maybe_apply_chat_template
|
||||
import torch
|
||||
|
||||
from open_r1.vlm_modules.vlm_module import VLMBaseModule
|
||||
|
||||
class Qwen2VLModule(VLMBaseModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_vlm_key(self):
|
||||
return "qwen"
|
||||
|
||||
def get_model_class(self, model_id: str, model_init_kwargs: dict):
|
||||
if "Qwen2-VL" in model_id:
|
||||
model_cls = Qwen2VLForConditionalGeneration
|
||||
elif "Qwen2.5-VL" in model_id:
|
||||
model_cls = Qwen2_5_VLForConditionalGeneration
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model_id}")
|
||||
return model_cls
|
||||
|
||||
def post_model_init(self, model, processing_class):
|
||||
pass
|
||||
|
||||
def get_processing_class(self):
|
||||
return AutoProcessor
|
||||
|
||||
def get_vision_modules_keywords(self):
|
||||
return ['visual']
|
||||
|
||||
def get_custom_multimodal_keywords(self):
|
||||
return ['pixel_values', 'image_grid_thw']
|
||||
|
||||
def get_non_generate_params(self):
|
||||
return []
|
||||
|
||||
def get_custom_processing_keywords(self):
|
||||
return [('image_processor', 'max_pixels'), ('image_processor', 'min_pixels')]
|
||||
|
||||
def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
|
||||
prompts_text = [maybe_apply_chat_template(example, processing_class)["prompt"] for example in inputs]
|
||||
return prompts_text
|
||||
|
||||
def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
|
||||
# FIXME
|
||||
# This could only process pure-multimodal or pure-text inputs
|
||||
if len(images) > 0:
|
||||
prompt_inputs = processing_class(
|
||||
text=prompts_text,
|
||||
images=images,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
padding_side=padding_side,
|
||||
add_special_tokens=add_special_tokens)
|
||||
else:
|
||||
prompt_inputs = processing_class(
|
||||
text=prompts_text,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
padding_side=padding_side,
|
||||
add_special_tokens=add_special_tokens)
|
||||
return prompt_inputs
|
||||
|
||||
@staticmethod
|
||||
def get_question_template(task_type: str):
|
||||
match task_type:
|
||||
case "rec":
|
||||
return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
|
||||
case "ic":
|
||||
return "{Question} First thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> json format answer here </answer>"
|
||||
case "odLength":
|
||||
SYSTEM_PROMPT = (
|
||||
#"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
||||
"First thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
||||
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
||||
"<think> reasoning process here </think><answer> answer here </answer>"
|
||||
)
|
||||
return SYSTEM_PROMPT + '\n' + "{Question}"
|
||||
case _:
|
||||
return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
|
||||
|
||||
@staticmethod
|
||||
def format_reward_rec(completions, **kwargs):
|
||||
"""Check if the Qwen model output matches a specific format."""
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
|
||||
completion_contents = [completion[0]["content"] for completion in completions]
|
||||
matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
|
||||
|
||||
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
||||
if os.getenv("DEBUG_MODE") == "true":
|
||||
log_path = os.getenv("LOG_PATH")
|
||||
with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
|
||||
f.write(f"------------- {current_time} Format reward -------------\n")
|
||||
for content, match in zip(completion_contents, matches):
|
||||
f.write(f"Content: {content}\n")
|
||||
f.write(f"Has format: {bool(match)}\n")
|
||||
return [1.0 if match else 0.0 for match in matches]
|
||||
|
||||
@staticmethod
|
||||
def iou_reward(completions, solution, **kwargs):
|
||||
"""Calculate IoU reward between predicted bounding box from Qwen model and ground truth bounding box."""
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
import json
|
||||
def iou(box1, box2):
|
||||
inter_x1 = max(box1[0], box2[0])
|
||||
inter_y1 = max(box1[1], box2[1])
|
||||
inter_x2 = min(box1[2]-1, box2[2]-1)
|
||||
inter_y2 = min(box1[3]-1, box2[3]-1)
|
||||
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
||||
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
|
||||
else:
|
||||
inter = 0
|
||||
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
|
||||
return float(inter)/union
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
rewards = []
|
||||
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
||||
answer_tag_pattern = r'<answer>(.*?)</answer>'
|
||||
bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
|
||||
for content, sol in zip(contents, solution):
|
||||
sol = re.findall(answer_tag_pattern, sol, re.DOTALL)[-1]
|
||||
sol = json.loads(sol.strip())
|
||||
reward = 0.0
|
||||
# Try symbolic verification first
|
||||
try:
|
||||
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
||||
if content_answer_match:
|
||||
content_answer = content_answer_match.group(1).strip()
|
||||
bbox_match = re.search(bbox_pattern, content_answer)
|
||||
if bbox_match:
|
||||
bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
|
||||
# if iou(bbox, sol) > 0.5:
|
||||
# reward = 1.0
|
||||
reward = iou(bbox, sol)
|
||||
except Exception:
|
||||
pass # Continue to next verification method if this fails
|
||||
|
||||
rewards.append(reward)
|
||||
if os.getenv("DEBUG_MODE") == "true":
|
||||
log_path = os.getenv("LOG_PATH")
|
||||
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
||||
image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
|
||||
problem = kwargs.get("problem")[0]
|
||||
if reward <= 1.0: # this condition can be changed for debug
|
||||
with open(log_path, "a", encoding='utf-8') as f:
|
||||
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
|
||||
f.write(f"image_path: {image_path}\n")
|
||||
f.write(f"problem: {problem}\n")
|
||||
f.write(f"Content: {content}\n")
|
||||
f.write(f"Solution: {sol}\n")
|
||||
return rewards
|
||||
|
||||
@staticmethod
|
||||
def select_reward_func(func: str, task_type: str):
|
||||
if func == "accuracy":
|
||||
match task_type:
|
||||
case "rec":
|
||||
return Qwen2VLModule.iou_reward
|
||||
case _:
|
||||
raise ValueError(f"Unsupported reward function: {func}")
|
||||
elif func == "format":
|
||||
match task_type:
|
||||
case "rec":
|
||||
return Qwen2VLModule.format_reward_rec
|
||||
case _:
|
||||
raise ValueError(f"Unsupported reward function: {func}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported reward function: {func}")
|
||||
@@ -0,0 +1,50 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Union
|
||||
import torch
|
||||
|
||||
|
||||
class VLMBaseModule(ABC):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def get_vlm_key(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_class(self, model_id: str, model_init_kwargs: dict):
|
||||
pass
|
||||
|
||||
def post_model_init(self, model, processing_class):
|
||||
pass
|
||||
|
||||
def is_embeds_input(self):
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def get_processing_class(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_vision_modules_keywords(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_custom_multimodal_keywords(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_non_generate_params(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_custom_processing_keywords(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors, padding, padding_side, add_special_tokens):
|
||||
pass
|
||||