[Feature] Support DPO, ORPO and Reward Model #743

Merged (46 commits, Jun 13, 2024)

Changes from all commits

Commits
9f2e35e  Support reward model and dpo (RangiLyu, May 22, 2024)
988fcaa  support train reward model (RangiLyu, May 22, 2024)
28663c1  fix config (RangiLyu, May 22, 2024)
96d6b00  fix lint (RangiLyu, May 22, 2024)
0015b42  fix lint (RangiLyu, May 22, 2024)
f8353d8  support jsonl dataset (RangiLyu, May 24, 2024)
f03dbc5  feat: support ORPO (RangiLyu, May 27, 2024)
e5c52a6  reorg configs (RangiLyu, Jun 3, 2024)
805bd5a  rename collate function (RangiLyu, Jun 3, 2024)
830cab5  rename collate function (RangiLyu, Jun 3, 2024)
1212f19  use varlen attention in validation (RangiLyu, Jun 3, 2024)
adee459  fix lint (RangiLyu, Jun 3, 2024)
c042c55  fix lint (RangiLyu, Jun 3, 2024)
08483c7  rebase main (RangiLyu, Jun 3, 2024)
6d3f1ec  update (RangiLyu, Jun 3, 2024)
b2589e8  add reference and update dpo loss (RangiLyu, Jun 3, 2024)
00a8d82  inherit sft (RangiLyu, Jun 5, 2024)
6c43a43  fix broadcast (RangiLyu, Jun 6, 2024)
4d0c96d  fix nan loss skip (RangiLyu, Jun 7, 2024)
bfead3b  support reward model sp (HIT-cwh, Jun 7, 2024)
5aafdb9  support dpo sp (HIT-cwh, Jun 7, 2024)
c571a70  support orpo sp (HIT-cwh, Jun 7, 2024)
4a79d2b  fix bugs (HIT-cwh, Jun 11, 2024)
3d26ad2  fix rebase (RangiLyu, Jun 11, 2024)
9afed8c  convert script (RangiLyu, Jun 11, 2024)
aba3646  fix precommit (RangiLyu, Jun 11, 2024)
776037d  mv convert script to model (RangiLyu, Jun 11, 2024)
06004fd  fix version check (RangiLyu, Jun 11, 2024)
e990be3  fix import (RangiLyu, Jun 11, 2024)
2952593  add comments of reward token (RangiLyu, Jun 11, 2024)
036e7f7  fix orpo cfg (RangiLyu, Jun 11, 2024)
aeaa98c  fix lint (RangiLyu, Jun 11, 2024)
114e5e3  fix lint (RangiLyu, Jun 11, 2024)
a885ade  remove seed (RangiLyu, Jun 11, 2024)
582e02f  remove seed (RangiLyu, Jun 11, 2024)
a26dcd9  add sp config (RangiLyu, Jun 11, 2024)
e9000b0  add reward sp config (RangiLyu, Jun 11, 2024)
fb656d7  fix convert (RangiLyu, Jun 12, 2024)
6ca9ad9  fix lora reward model convert (RangiLyu, Jun 12, 2024)
1a28acd  fix qlora reward merge (RangiLyu, Jun 12, 2024)
7a978fc  update dpo loss (RangiLyu, Jun 12, 2024)
c74721c  log reward acc and margin in dpo (RangiLyu, Jun 12, 2024)
5c711ae  update logits mask (RangiLyu, Jun 12, 2024)
88ec30c  unpack logits first (RangiLyu, Jun 12, 2024)
ddf5fa4  more loss setting in dpo cfgs (RangiLyu, Jun 12, 2024)
9d60565  more loss setting in orpo cfgs (RangiLyu, Jun 12, 2024)
2 changes: 1 addition & 1 deletion .pre-commit-config-zh-cn.yaml
@@ -1,4 +1,4 @@
-exclude: ^tests/data/
+exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/
 repos:
   - repo: https://gitee.com/openmmlab/mirrors-flake8
     rev: 5.0.4
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,4 +1,4 @@
-exclude: ^tests/data/
+exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/
 repos:
   - repo: https://github.com/PyCQA/flake8
     rev: 5.0.4
201 changes: 201 additions & 0 deletions xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full.py
@@ -0,0 +1,201 @@
# Copyright (c) OpenMMLab. All rights reserved.
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer

from xtuner.dataset.collate_fns.preference_collate_fn import \
    preference_collate_fn
from xtuner.dataset.preference_dataset import (build_preference_dataset,
                                               orpo_dpo_mix_40k_map_fn)
from xtuner.engine.hooks import (EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model.dpo import DPO
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
# PART 1 Settings #
#######################################################################
# Model
pretrained_model_name_or_path = 'internlm/internlm2-chat-1_8b-sft'
use_varlen_attn = False
dpo_loss_type = 'sigmoid' # One of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo_hard', 'nca_pair', 'robust'] # noqa: E501
loss_beta = 0.1
label_smoothing = 0.0

# Data
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048

# Scheduler & Optimizer
batch_size = 1 # per_device
accumulative_counts = 16
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 5e-7 # refer to alignment handbook
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip
warmup_ratio = 0.03
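# Note: the effective batch size per optimizer step is
# batch_size * accumulative_counts * number of GPUs (1 x 16 per device here).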

# Save
save_steps = 500
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
    'What famous British author, known for his tales of mystery and the macabre, shares his initials with a common abbreviation for "rest in peace"?',  # noqa: E501
    'Please tell me five scenic spots in Shanghai',
    '890729 - 425663? Only respond with math and no words.'
]

#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=DPO,
    use_varlen_attn=use_varlen_attn,
    loss_type=dpo_loss_type,
    beta=loss_beta,
    label_smoothing=label_smoothing,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True))

#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
train_dataset = dict(
    type=build_preference_dataset,
    dataset=dict(type=load_dataset, path='mlabonne/orpo-dpo-mix-40k'),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=orpo_dpo_mix_40k_map_fn,
    is_dpo=True,
    is_reward=False,
    reward_token_id=-1,
    num_proc=32,
    use_varlen_attn=use_varlen_attn,
    shuffle_before_pack=True,
)
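# reward_token_id is only consumed when is_reward=True; for DPO (is_dpo=True)
# the placeholder value -1 presumably leaves it unused.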

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(
        type=preference_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')
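# AMP training: the forward/backward pass runs in float16 with dynamic loss
# scaling, and gradients are clipped to max_norm=1 before each update.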

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]
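# With warmup_ratio = 0.03 and max_epochs = 3, the linear warmup covers the
# first 0.09 epochs before the cosine schedule decays the LR towards eta_min.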

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
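
For reference on the loss settings in the config above (dpo_loss_type = 'sigmoid', loss_beta, label_smoothing), the sketch below shows how the standard sigmoid DPO objective (Rafailov et al.) is conventionally computed, together with the implicit reward accuracy and margin that the commit log ("log reward acc and margin in dpo") indicates this PR reports. It is a minimal illustration under assumed tensor names (policy_chosen_logps and friends), not the implementation added in this PR.

import torch
import torch.nn.functional as F


def sigmoid_dpo_loss(policy_chosen_logps: torch.Tensor,
                     policy_rejected_logps: torch.Tensor,
                     ref_chosen_logps: torch.Tensor,
                     ref_rejected_logps: torch.Tensor,
                     beta: float = 0.1,
                     label_smoothing: float = 0.0) -> torch.Tensor:
    # Per-pair log-ratios of chosen vs. rejected responses under the policy
    # and under the frozen reference model.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    logits = pi_logratios - ref_logratios
    # label_smoothing interpolates towards the flipped preference label
    # (conservative DPO); with label_smoothing = 0 this is -logsigmoid(beta * logits).
    loss = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing)
    return loss.mean()


def implicit_reward_metrics(policy_chosen_logps, policy_rejected_logps,
                            ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # The implicit reward of a response is beta times its policy-vs-reference
    # log-probability ratio; accuracy is how often the chosen response wins.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    accuracy = (chosen_rewards > rejected_rewards).float().mean()
    margin = (chosen_rewards - rejected_rewards).mean()
    return accuracy, margin

With the values in this config (beta = loss_beta = 0.1, label_smoothing = 0.0), the loss reduces to -logsigmoid(0.1 * logits) averaged over the preference pairs.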