mirror of https://github.com/leigest519/ScreenCoder.git, synced 2026-02-13 02:02:48 +00:00
Add post-training folder
post-training/LLaMA-Factory/tests/check_license.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from pathlib import Path


KEYWORDS = ("Copyright", "2025", "LlamaFactory")


def main():
    path_list: list[Path] = []
    for check_dir in sys.argv[1:]:
        path_list.extend(Path(check_dir).glob("**/*.py"))

    for path in path_list:
        with open(path.absolute(), encoding="utf-8") as f:
            file_content = f.read().strip().split("\n")
            if not file_content[0]:
                continue

            print(f"Check license: {path}")
            assert all(keyword in file_content[0] for keyword in KEYWORDS), f"File {path} does not contain license."


if __name__ == "__main__":
    main()
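Note: the checker walks every *.py file under the directories given on the command line and asserts that the first line of each file carries the license keywords. A minimal invocation sketch (the directory names are illustrative assumptions, not taken from this diff):

    # scan the source and test trees for the Apache header
    python tests/check_license.py src tests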
post-training/LLaMA-Factory/tests/e2e/test_chat.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from llamafactory.chat import ChatModel


TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "do_sample": False,
    "max_new_tokens": 1,
}

MESSAGES = [
    {"role": "user", "content": "Hi"},
]

EXPECTED_RESPONSE = "_rho"


def test_chat():
    chat_model = ChatModel(INFER_ARGS)
    assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE


def test_stream_chat():
    chat_model = ChatModel(INFER_ARGS)
    response = ""
    for token in chat_model.stream_chat(MESSAGES):
        response += token

    assert response == EXPECTED_RESPONSE
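Note: with "do_sample": False and "max_new_tokens": 1, decoding is greedy and capped at a single token, so the tiny random model's reply is deterministic; that is what makes the hard-coded EXPECTED_RESPONSE = "_rho" a stable assertion for both the blocking and streaming paths. A small usage sketch reusing the objects defined above (the follow-up turn is an illustrative assumption, not part of the test):

    chat_model = ChatModel(INFER_ARGS)
    history = list(MESSAGES)
    reply = chat_model.chat(history)[0].response_text  # "_rho" for the tiny model
    history.append({"role": "assistant", "content": reply})
    history.append({"role": "user", "content": "Hi again"})
    # a second call continues the conversation with the accumulated history
    print(chat_model.chat(history)[0].response_text)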
post-training/LLaMA-Factory/tests/e2e/test_sglang.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import pytest

from llamafactory.chat import ChatModel
from llamafactory.extras.packages import is_sglang_available


MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

INFER_ARGS = {
    "model_name_or_path": MODEL_NAME,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "infer_backend": "sglang",
    "do_sample": False,
    "max_new_tokens": 1,
}

MESSAGES = [
    {"role": "user", "content": "Hi"},
]


@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_chat():
    r"""Test the SGLang engine's basic chat functionality."""
    chat_model = ChatModel(INFER_ARGS)
    response = chat_model.chat(MESSAGES)[0]
    # TODO: Change to EXPECTED_RESPONSE
    print(response.response_text)


@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_stream_chat():
    r"""Test the SGLang engine's streaming chat functionality."""
    chat_model = ChatModel(INFER_ARGS)

    response = ""
    for token in chat_model.stream_chat(MESSAGES):
        response += token

    print("Complete response:", response)
    assert response, "Should receive a non-empty response"


# Run tests if executed directly
if __name__ == "__main__":
    if not is_sglang_available():
        print("SGLang is not available. Please install it.")
        sys.exit(1)

    test_chat()
    test_stream_chat()
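Note: unlike test_chat.py, these tests load the full meta-llama/Llama-3.2-1B-Instruct model, whose first token is not pinned yet (see the TODO), so they only assert that a non-empty response comes back. Thanks to the __main__ guard, the file can also be run directly from the LLaMA-Factory directory:

    python tests/e2e/test_sglang.py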
post-training/LLaMA-Factory/tests/e2e/test_train.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from llamafactory.train.tuner import export_model, run_exp


DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "do_train": True,
    "finetuning_type": "lora",
    "dataset_dir": "REMOTE:" + DEMO_DATA,
    "template": "llama3",
    "cutoff_len": 1,
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "max_steps": 1,
    "report_to": "none",
}

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "adapter_name_or_path": TINY_LLAMA_ADAPTER,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
}

OS_NAME = os.getenv("OS_NAME", "")


@pytest.mark.parametrize(
    "stage,dataset",
    [
        ("pt", "c4_demo"),
        ("sft", "alpaca_en_demo"),
        ("dpo", "dpo_en_demo"),
        ("kto", "kto_en_demo"),
        pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
    ],
)
def test_run_exp(stage: str, dataset: str):
    output_dir = os.path.join("output", f"train_{stage}")
    run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
    assert os.path.exists(output_dir)


def test_export():
    export_dir = os.path.join("output", "llama3_export")
    export_model({"export_dir": export_dir, **INFER_ARGS})
    assert os.path.exists(export_dir)
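Note: run_exp and export_model both take plain argument dicts, so the same flows can be scripted outside pytest. A sketch composing the dicts exactly as the tests do (the output paths mirror the test's layout):

    # one SFT step on the demo data, then export the adapter-merged model
    run_exp({"stage": "sft", "dataset": "alpaca_en_demo", "output_dir": "output/train_sft", **TRAIN_ARGS})
    export_model({"export_dir": "output/llama3_export", **INFER_ARGS})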
post-training/LLaMA-Factory/tests/eval/test_eval_template.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from llamafactory.eval.template import get_eval_template


def test_eval_template_en():
    support_set = [
        {
            "question": "Fewshot question",
            "A": "Fewshot1",
            "B": "Fewshot2",
            "C": "Fewshot3",
            "D": "Fewshot4",
            "answer": "B",
        }
    ]
    example = {
        "question": "Target question",
        "A": "Target1",
        "B": "Target2",
        "C": "Target3",
        "D": "Target4",
        "answer": "C",
    }
    template = get_eval_template(name="en")
    messages = template.format_example(example, support_set=support_set, subject_name="SubName")
    assert messages == [
        {
            "role": "user",
            "content": (
                "The following are multiple choice questions (with answers) about SubName.\n\n"
                "Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:"
            ),
        },
        {"role": "assistant", "content": "B"},
        {
            "role": "user",
            "content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:",
        },
        {"role": "assistant", "content": "C"},
    ]


def test_eval_template_zh():
    support_set = [
        {
            "question": "示例问题",
            "A": "示例答案1",
            "B": "示例答案2",
            "C": "示例答案3",
            "D": "示例答案4",
            "answer": "B",
        }
    ]
    example = {
        "question": "目标问题",
        "A": "目标答案1",
        "B": "目标答案2",
        "C": "目标答案3",
        "D": "目标答案4",
        "answer": "C",
    }
    template = get_eval_template(name="zh")
    messages = template.format_example(example, support_set=support_set, subject_name="主题")
    assert messages == [
        {
            "role": "user",
            "content": (
                "以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n"
                "示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:"
            ),
        },
        {"role": "assistant", "content": "B"},
        {
            "role": "user",
            "content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:",
        },
        {"role": "assistant", "content": "C"},
    ]
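Note: format_example renders every support-set item as an answered question ahead of the target question, so a k-shot prompt is just k assistant-answered turns followed by the unanswered target (whose gold answer comes back as the final assistant message). A zero-shot sketch under the same API (an empty support set is an assumption these tests do not exercise):

    template = get_eval_template(name="en")
    messages = template.format_example(example, support_set=[], subject_name="SubName")
    # expected: one user turn holding the subject preamble plus the target question,
    # followed by the gold answer as the assistant turn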
@@ -0,0 +1,50 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import load_infer_model


TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "template": "llama3",
}


@pytest.mark.xfail(is_transformers_version_greater_than("4.48"), reason="Attention refactor.")
def test_attention():
    attention_available = ["disabled"]
    if is_torch_sdpa_available():
        attention_available.append("sdpa")

    if is_flash_attn_2_available():
        attention_available.append("fa2")

    llama_attention_classes = {
        "disabled": "LlamaAttention",
        "sdpa": "LlamaSdpaAttention",
        "fa2": "LlamaFlashAttention2",
    }
    for requested_attention in attention_available:
        model = load_infer_model(flash_attn=requested_attention, **INFER_ARGS)
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                assert module.__class__.__name__ == llama_attention_classes[requested_attention]
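Note: the test asserts that the requested backend ("disabled", "sdpa", or "fa2") is reflected in the attention class transformers actually instantiated. The xfail marker matches its reason string: the attention refactor in transformers 4.48 changed how the implementations are organized, so the per-backend class names checked here can no longer be relied on from that version onward.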
@@ -0,0 +1,66 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch

from llamafactory.extras.misc import get_current_device
from llamafactory.train.test_utils import load_train_model


TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "lora_target": "all",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}


@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
    model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS)
    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
        assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing


def test_unsloth_gradient_checkpointing():
    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"


def test_upcast_layernorm():
    model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
    for name, param in model.named_parameters():
        if param.ndim == 1 and "norm" in name:
            assert param.dtype == torch.float32


def test_upcast_lmhead_output():
    model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
    inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
    outputs: torch.Tensor = model.get_output_embeddings()(inputs)
    assert outputs.dtype == torch.float32
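Note: gradient checkpointing trades compute for memory by recomputing activations in the backward pass; the vanilla test only verifies the flag is toggled consistently on every module that exposes it, while the Unsloth variant checks that the checkpointing function has been rebound to UnslothGradientCheckpointing. The two upcast tests pin the mixed-precision contract: norm weights and lm_head outputs stay in float32 even when the rest of the model runs in float16.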
@@ -0,0 +1,43 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from llamafactory.model.model_utils.misc import find_expanded_modules


HF_TOKEN = os.getenv("HF_TOKEN")


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_expanded_modules():
    config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)

    expanded_modules = find_expanded_modules(model, ["q_proj", "v_proj"], num_layer_trainable=4)
    assert expanded_modules == [
        "model.layers.7.self_attn.q_proj",
        "model.layers.7.self_attn.v_proj",
        "model.layers.15.self_attn.q_proj",
        "model.layers.15.self_attn.v_proj",
        "model.layers.23.self_attn.q_proj",
        "model.layers.23.self_attn.v_proj",
        "model.layers.31.self_attn.q_proj",
        "model.layers.31.self_attn.v_proj",
    ]
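Note: the expected names follow from Meta-Llama-3-8B-Instruct having 32 decoder layers. With num_layer_trainable=4, the helper appears to keep the last layer of every (32 // 4) = 8-layer stride, which a quick arithmetic check confirms against the golden list:

    stride = 32 // 4  # 8
    print([i * stride + (stride - 1) for i in range(4)])  # [7, 15, 23, 31]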
@@ -0,0 +1,68 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data


@pytest.mark.parametrize(
    "attention_mask,golden_seq_lens",
    [
        (
            [
                [1, 1, 2, 2, 2, 0],
                [1, 2, 2, 3, 3, 3],
            ],
            [2, 3, 1, 2, 3],
        ),
        (
            [[1]],
            [1],
        ),
    ],
)
def test_get_seqlens_in_batch(attention_mask, golden_seq_lens):
    attention_mask_with_indices = torch.tensor(attention_mask)
    seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices)
    assert torch.all(seqlens_in_batch == torch.tensor(golden_seq_lens))


@pytest.mark.parametrize(
    "attention_mask,golden_indices,golden_cu_seqlens,golden_max_seqlen",
    [
        (
            [
                [1, 1, 2, 2, 2, 0],
                [1, 2, 2, 3, 3, 3],
            ],
            [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11],
            [0, 2, 5, 6, 8, 11],
            3,
        ),
        (
            [[1]],
            [0],
            [0, 1],
            1,
        ),
    ],
)
def test_get_unpad_data(attention_mask, golden_indices, golden_cu_seqlens, golden_max_seqlen):
    attention_mask_with_indices = torch.tensor(attention_mask)
    indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices)
    assert torch.all(indices == torch.tensor(golden_indices))
    assert torch.all(cu_seqlens == torch.tensor(golden_cu_seqlens, dtype=torch.int32))
    assert max_seqlen_in_batch == golden_max_seqlen
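Note: in a packed batch, the "attention mask" carries sequence indices rather than 0/1 flags, so [1, 1, 2, 2, 2, 0] encodes a length-2 sequence, a length-3 sequence, and one padding slot. get_unpad_data returns the flat indices of the non-padding tokens, the cumulative sequence lengths, and the longest sequence length. The cumulative lengths are a zero-prefixed running sum of the per-sequence lengths, which can be sanity-checked directly:

    import torch

    seqlens = torch.tensor([2, 3, 1, 2, 3], dtype=torch.int32)  # golden_seq_lens from above
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    print(cu_seqlens.tolist())  # [0, 2, 5, 6, 8, 11], matching golden_cu_seqlens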
@@ -0,0 +1,70 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from transformers import AutoConfig, AutoModelForVision2Seq

from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model.adapter import init_adapter


@pytest.mark.parametrize("freeze_vision_tower", (False, True))
@pytest.mark.parametrize("freeze_multi_modal_projector", (False, True))
@pytest.mark.parametrize("freeze_language_model", (False, True))
def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, freeze_language_model: bool):
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
    finetuning_args = FinetuningArguments(
        finetuning_type="full",
        freeze_vision_tower=freeze_vision_tower,
        freeze_multi_modal_projector=freeze_multi_modal_projector,
        freeze_language_model=freeze_language_model,
    )
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)

    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
    for name, param in model.named_parameters():
        if any(key in name for key in ["visual.patch_embed", "visual.blocks"]):
            assert param.requires_grad != freeze_vision_tower
        elif "visual.merger" in name:
            assert param.requires_grad != freeze_multi_modal_projector
        else:
            assert param.requires_grad != freeze_language_model


@pytest.mark.parametrize("freeze_vision_tower", (False, True))
def test_visual_lora(freeze_vision_tower: bool):
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
    finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower)
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)

    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
    trainable_params, frozen_params = set(), set()
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.add(name)
        else:
            frozen_params.add(name)

    if freeze_vision_tower:
        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" not in trainable_params
    else:
        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" in trainable_params

    assert "merger" not in trainable_params
    assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight" in trainable_params
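Note: the name checks encode Qwen2-VL's three-way split: visual.patch_embed and visual.blocks form the vision tower, visual.merger is the multi-modal projector, and everything else belongs to the language model. A rough way to summarize what a configuration leaves trainable, meant to run right after init_adapter inside the test (the grouping mirrors the test's name checks; the counting itself is illustrative):

    groups = {"vision_tower": 0, "projector": 0, "language_model": 0}
    for name, param in model.named_parameters():
        if "visual.patch_embed" in name or "visual.blocks" in name:
            group = "vision_tower"
        elif "visual.merger" in name:
            group = "projector"
        else:
            group = "language_model"
        groups[group] += int(param.requires_grad)
    print(groups)  # trainable tensor count per component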
post-training/LLaMA-Factory/tests/model/test_base.py (new file, 48 lines)
@@ -0,0 +1,48 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, patch_valuehead_model


TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "template": "llama3",
    "infer_dtype": "float16",
}


@pytest.fixture
def fix_valuehead_cpu_loading():
    patch_valuehead_model()


def test_base():
    model = load_infer_model(**INFER_ARGS)
    ref_model = load_reference_model(TINY_LLAMA3)
    compare_model(model, ref_model)


@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_valuehead():
    model = load_infer_model(add_valuehead=True, **INFER_ARGS)
    ref_model = load_reference_model(TINY_LLAMA_VALUEHEAD, add_valuehead=True)
    compare_model(model, ref_model)
post-training/LLaMA-Factory/tests/model/test_freeze.py (new file, 72 lines)
@@ -0,0 +1,72 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from llamafactory.train.test_utils import load_infer_model, load_train_model


TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "freeze",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "finetuning_type": "freeze",
    "template": "llama3",
    "infer_dtype": "float16",
}


def test_freeze_train_all_modules():
    model = load_train_model(freeze_trainable_layers=1, **TRAIN_ARGS)
    for name, param in model.named_parameters():
        if name.startswith("model.layers.1."):
            assert param.requires_grad is True
            assert param.dtype == torch.float32
        else:
            assert param.requires_grad is False
            assert param.dtype == torch.float16


def test_freeze_train_extra_modules():
    model = load_train_model(freeze_trainable_layers=1, freeze_extra_modules="embed_tokens,lm_head", **TRAIN_ARGS)
    for name, param in model.named_parameters():
        if name.startswith("model.layers.1.") or any(module in name for module in ["embed_tokens", "lm_head"]):
            assert param.requires_grad is True
            assert param.dtype == torch.float32
        else:
            assert param.requires_grad is False
            assert param.dtype == torch.float16


def test_freeze_inference():
    model = load_infer_model(**INFER_ARGS)
    for param in model.parameters():
        assert param.requires_grad is False
        assert param.dtype == torch.float16
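Note: freeze_trainable_layers=1 keeps only the last decoder layer trainable, and the tiny test model appears to have two decoder layers, so that last layer is model.layers.1, which is exactly the prefix the tests key on. The trainable parameters are also upcast to float32 while everything frozen stays in the fp16 the model was loaded in.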
post-training/LLaMA-Factory/tests/model/test_full.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from llamafactory.train.test_utils import load_infer_model, load_train_model


TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "finetuning_type": "full",
    "template": "llama3",
    "infer_dtype": "float16",
}


def test_full_train():
    model = load_train_model(**TRAIN_ARGS)
    for param in model.parameters():
        assert param.requires_grad is True
        assert param.dtype == torch.float32


def test_full_inference():
    model = load_infer_model(**INFER_ARGS)
    for param in model.parameters():
        assert param.requires_grad is False
        assert param.dtype == torch.float16
post-training/LLaMA-Factory/tests/model/test_lora.py (new file, 109 lines)
@@ -0,0 +1,109 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch

from llamafactory.train.test_utils import (
    check_lora_model,
    compare_model,
    load_infer_model,
    load_reference_model,
    load_train_model,
    patch_valuehead_model,
)


TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")

TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "adapter_name_or_path": TINY_LLAMA_ADAPTER,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
}


@pytest.fixture
def fix_valuehead_cpu_loading():
    patch_valuehead_model()


def test_lora_train_qv_modules():
    model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
    linear_modules, _ = check_lora_model(model)
    assert linear_modules == {"q_proj", "v_proj"}


def test_lora_train_all_modules():
    model = load_train_model(lora_target="all", **TRAIN_ARGS)
    linear_modules, _ = check_lora_model(model)
    assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}


def test_lora_train_extra_modules():
    model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
    _, extra_modules = check_lora_model(model)
    assert extra_modules == {"embed_tokens", "lm_head"}


def test_lora_train_old_adapters():
    model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
    compare_model(model, ref_model)


def test_lora_train_new_adapters():
    model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
    compare_model(
        model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
    )


@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_lora_train_valuehead():
    model = load_train_model(add_valuehead=True, **TRAIN_ARGS)
    ref_model = load_reference_model(TINY_LLAMA_VALUEHEAD, is_trainable=True, add_valuehead=True)
    state_dict = model.state_dict()
    ref_state_dict = ref_model.state_dict()
    assert torch.allclose(state_dict["v_head.summary.weight"], ref_state_dict["v_head.summary.weight"])
    assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])


def test_lora_inference():
    model = load_infer_model(**INFER_ARGS)
    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
    compare_model(model, ref_model)
post-training/LLaMA-Factory/tests/model/test_pissa.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model


TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

TINY_LLAMA_PISSA = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "pissa_init": True,
    "pissa_iter": -1,
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA_PISSA,
    "adapter_name_or_path": TINY_LLAMA_PISSA,
    "adapter_folder": "pissa_init",
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
}


@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")
def test_pissa_train():
    model = load_train_model(**TRAIN_ARGS)
    ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
    compare_model(model, ref_model)


@pytest.mark.xfail(reason="Known connection error.")
def test_pissa_inference():
    model = load_infer_model(**INFER_ARGS)
    ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)
    ref_model = ref_model.merge_and_unload()
    compare_model(model, ref_model)
post-training/LLaMA-Factory/tests/train/test_sft_trainer.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import Any

import pytest
from transformers import DataCollatorWithPadding

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer


DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "max_steps": 1,
    "report_to": "none",
}


@dataclass
class DataCollatorWithVerbose(DataCollatorWithPadding):
    verbose_list: list[dict[str, Any]] = field(default_factory=list)

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        self.verbose_list.extend(features)
        batch = super().__call__(features)
        return {k: v[:, :1] for k, v in batch.items()}  # truncate input length


@pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool):
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        {
            "output_dir": os.path.join("output", f"shuffle{str(disable_shuffling).lower()}"),
            "disable_shuffling": disable_shuffling,
            **TRAIN_ARGS,
        }
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    data_collator = DataCollatorWithVerbose(tokenizer=tokenizer)
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        **dataset_module,
        **tokenizer_module,
    )
    trainer.train()
    if disable_shuffling:
        assert data_collator.verbose_list[0]["input_ids"] == dataset_module["train_dataset"][0]["input_ids"]
    else:
        assert data_collator.verbose_list[0]["input_ids"] != dataset_module["train_dataset"][0]["input_ids"]
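Note: DataCollatorWithVerbose is what makes the shuffling observable: it records each raw feature before padding and then truncates every batch tensor to a single position, keeping the lone training step cheap. After trainer.train(), comparing the first recorded input_ids against the first dataset row reveals whether the sampler shuffled: equal when disable_shuffling is set, different otherwise.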
post-training/LLaMA-Factory/tests/version.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
# change if test fails or cache is outdated
0.9.3.103