
Commit

add/fix pretrain_loss
ShuaibinLi committed Jun 13, 2024
1 parent 19eeee4 commit f3334f0
Showing 18 changed files with 724 additions and 789 deletions.
.gitignore (1 change: 1 addition & 0 deletions)
@@ -113,6 +113,7 @@ data
 *.pkl.json
 *.log.json
 work_dirs/
+rlhf_trainlog*/
 
 # Pytorch
 *.pth
examples/rlhf/four_model_8gpu.py (154 changes: 91 additions & 63 deletions)
@@ -1,62 +1,80 @@
 import torch
 
+MAX_PROMPT_LEN = 1024
+MAX_ANSWER_LEN = 1024
+
+PROMPT_BATCH_SIZE = 256
+PRETRAIN_BATCH_SIZE = 32
+
+GENERATE_MICRO_BATCH_SIZE = 16
+AC_INFER_MICRO_BATCH_SIZE = 8
+REF_INFER_MICRO_BATCH_SIZE = 8
+TRAIN_MICRO_BATCH_SIZE = 2
+
+ZERO_STAGE = 3
+ACTOR_DP_SIZE = 2
+CRITIC_DP_SIZE = 2
+ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE
+                           ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE
+CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE
+
+MODEL_DTYPE = 'auto'
+
 tokenizer_config = dict(
     pad_token_id=0,
     eos_token_id=92542,
     padding_side='left',
 )
 
 rollout_config = dict(
-    actor_micro_bs=32,
-    reward_micro_bs=32,
+    actor_micro_bs=GENERATE_MICRO_BATCH_SIZE,
+    reward_micro_bs=GENERATE_MICRO_BATCH_SIZE,
     clip_reward_min=-5,
     clip_reward_max=5,
-    max_new_tokens=10,
-    async_reward=True,
+    max_new_tokens=MAX_ANSWER_LEN,
+    write_to_file=True,
     generate_kwargs={
         'do_sample': True,
         'temperature': 1.0,
         'top_k': 0,
         'top_p': 0.9,
-        'pad_token_id': 0,
-        'eos_token_id': 92542,
-        'early_stopping': True,
-        'num_beams': 1,
         'min_new_tokens': 1,
-    })
+        'num_beams': 1,
+        'early_stopping': True,
+        'eos_token_id': 92542,
+        'pad_token_id': 0,
+    },
+)
 
 repeater_config = dict(
-    actor_micro_bs=8,
-    ref_micro_bs=8,
-    critic_micro_bs=32,
-    reward_scale=False,
-    fine_grained_rm=False,
-    value_ema=False,
+    actor_micro_bs=AC_INFER_MICRO_BATCH_SIZE,
+    critic_micro_bs=AC_INFER_MICRO_BATCH_SIZE,
+    ref_micro_bs=REF_INFER_MICRO_BATCH_SIZE,
     kl_coeff=0.01,
     gamma=1.0,
     gae_lambda=0.99,
     answer_end_id=92542,
     norm_rewards=True,
 )
 
 train_config = dict(
     ppo_minibatch=64,
     value_minibatch=64,
-    actor_micro_bs=2,
-    critic_micro_bs=2,
-    pretrain_step=0,
-    save_interval=800,
+    actor_micro_bs=TRAIN_MICRO_BATCH_SIZE,
+    critic_micro_bs=TRAIN_MICRO_BATCH_SIZE,
+    ppo_loss_weight=1.0,
+    pretrain_loss_weight=0.5,
+    pretrain_step=20,
+    save_interval=40,
 )
 
-critic_model_path = 'internlm/internlm2-chat-1_8b-sft'
-
 model_configs = dict(
     actor=dict(
         model_path='internlm/internlm2-chat-1_8b-sft',
         model_type='actor',
-        use_flash_attn=False,
         trainer_config=dict(
+            torch_dtype=MODEL_DTYPE,
             trainer_type='huggingface',
-            torch_dtype=torch.float32,
+            use_flash_attn=True,
+            gradient_checkpointing=False,
             train_kwargs=dict(
                 micro_bsz=1,
                 lr=1e-6,
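
Note on the derived constants above: with these values, ACTOR_GRADIENT_ACC_STEP = (256 + 32) // 2 // 2 = 72 and CRITIC_GRADIENT_ACC_STEP = 256 // 2 // 2 = 64. A minimal sketch of the arithmetic, plus an assumed combination of the two new loss weights (the actual reduction lives in the trainer code, not in this config):

# Sanity check of the derived constants, using the values set above.
PROMPT_BATCH_SIZE = 256
PRETRAIN_BATCH_SIZE = 32
TRAIN_MICRO_BATCH_SIZE = 2
ACTOR_DP_SIZE = CRITIC_DP_SIZE = 2

actor_acc = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE
critic_acc = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE
assert (actor_acc, critic_acc) == (72, 64)

# Assumed combination of the two loss terms, based on the new
# ppo_loss_weight / pretrain_loss_weight fields and the commit message;
# the real reduction happens in the trainer, not in this config file.
def total_loss(ppo_loss, pretrain_loss):
    return 1.0 * ppo_loss + 0.5 * pretrain_loss
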
@@ -65,14 +83,14 @@
                 loss_type='per_seq',
             ),
             parallel=dict(
-                data=dict(size=2, mode='deepspeed'),
+                data=dict(size=ACTOR_DP_SIZE, mode='deepspeed'),
                 tensor=dict(size=1, mode='1d'),
                 pipeline=dict(size=1, interleaved_overlap=False),
                 sequence=False,
             ),
             deepspeed_config={
                 'zero_optimization': {
-                    'stage': 2,
+                    'stage': ZERO_STAGE,
                     'offload_param': {
                         'device': 'none'
                     },
@@ -91,34 +109,21 @@
                 'data_types': {
                     'grad_accum_dtype': 'fp32'
                 },
-                'train_micro_batch_size_per_gpu': 2,
-                'gradient_accumulation_steps': 16,
-                'train_batch_size': 64
-            }),
-        generator_config=dict(shared_with_trainer=True, ),
-    ),
-    reference=dict(
-        model_path='internlm/internlm2-chat-1_8b-sft',
-        model_type='reference',
-        use_flash_attn=False,
-        trainer_config=dict(
-            torch_dtype=torch.float32,
-            trainer_type='huggingface',
-            parallel=dict(
-                data=dict(size=2, mode='ddp'),
-                tensor=dict(size=1, mode='1d'),
-                pipeline=dict(size=1, interleaved_overlap=False),
-                sequence=False,
-            ),
-        ),
-    ),
+                'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE,
+                'gradient_accumulation_steps': ACTOR_GRADIENT_ACC_STEP,
+                'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE,
+            },
+        ),
+        generator_config=dict(shared_with_trainer=True, ),
+    ),
     critic=dict(
-        model_path=critic_model_path,
+        model_path=None,
         model_type='critic',
-        use_flash_attn=False,
         trainer_config=dict(
-            torch_dtype='auto',
+            torch_dtype=MODEL_DTYPE,
             trainer_type='huggingface',
+            use_flash_attn=True,
+            gradient_checkpointing=False,
             train_kwargs=dict(
                 micro_bsz=1,
                 lr=5e-6,
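
DeepSpeed validates that train_batch_size equals train_micro_batch_size_per_gpu × gradient_accumulation_steps × the data-parallel world size, which is why these fields are now wired through the shared constants instead of hard-coded numbers. A quick check of both models' figures (a sketch reusing the constants defined at the top of the config; the actor trains on prompts plus pretrain data, the critic on prompts only):

# DeepSpeed consistency: train_batch_size ==
#   train_micro_batch_size_per_gpu * gradient_accumulation_steps * dp size
assert TRAIN_MICRO_BATCH_SIZE * ACTOR_GRADIENT_ACC_STEP * ACTOR_DP_SIZE == \
    PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE   # 2 * 72 * 2 == 288
assert TRAIN_MICRO_BATCH_SIZE * CRITIC_GRADIENT_ACC_STEP * CRITIC_DP_SIZE == \
    PROMPT_BATCH_SIZE                         # 2 * 64 * 2 == 256
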
@@ -127,14 +132,14 @@
                 loss_type='per_seq',
             ),
             parallel=dict(
-                data=dict(size=2, mode='deepspeed'),
+                data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'),
                 tensor=dict(size=1, mode='1d'),
                 pipeline=dict(size=1, interleaved_overlap=False),
                 sequence=False,
             ),
             deepspeed_config={
                 'zero_optimization': {
-                    'stage': 2,
+                    'stage': ZERO_STAGE,
                     'offload_param': {
                         'device': 'none'
                     },
@@ -152,20 +157,36 @@
                 'data_types': {
                     'grad_accum_dtype': 'fp32'
                 },
-                'train_micro_batch_size_per_gpu': 2,
-                'gradient_accumulation_steps': 16,
-                'train_batch_size': 64
-            }),
+                'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE,
+                'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP,
+                'train_batch_size': PROMPT_BATCH_SIZE,
+            },
+        ),
     ),
+    reference=dict(
+        model_path='internlm/internlm2-chat-1_8b-sft',
+        model_type='reference',
+        trainer_config=dict(
+            torch_dtype=MODEL_DTYPE,
+            trainer_type='huggingface',
+            use_flash_attn=True,
+            parallel=dict(
+                data=dict(size=1, mode='ddp'),
+                tensor=dict(size=1, mode='1d'),
+                pipeline=dict(size=1, interleaved_overlap=False),
+                sequence=False,
+            ),
+        ),
+    ),
     reward=dict(
-        model_path=critic_model_path,
+        model_path=None,
         model_type='reward',
-        use_flash_attn=False,
         trainer_config=dict(
+            torch_dtype=MODEL_DTYPE,
             trainer_type='huggingface',
-            torch_dtype='auto',
+            use_flash_attn=True,
             parallel=dict(
-                data=dict(size=2, mode='ddp'),
+                data=dict(size=1, mode='ddp'),
                 tensor=dict(size=1, mode='1d'),
                 pipeline=dict(size=1, interleaved_overlap=False),
                 sequence=False,
@@ -175,14 +196,21 @@
 )
 
 dataset_config = {
-    'num_samples_each_epoch':
-    64,
+    'prompt_samples_each_epoch':
+    PROMPT_BATCH_SIZE,
+    'pretrain_samples_each_epoch':
+    PRETRAIN_BATCH_SIZE,
     'max_seq_len':
-    1024,
+    MAX_PROMPT_LEN,
     'random_seed':
     1024,
-    'ppo_datas': [
+    # "sample_strategy": "in_data",
+    # "ratio_within_datasets": False,
+    'prompt_datasets': [
         'Anthropic/hh-rlhf/helpful-base::1.0',
         'Anthropic/hh-rlhf/harmless-base::0.5',
     ],
+    'pretrain_datasets': [
+        'Anthropic/hh-rlhf/helpful-base::1.0',
+    ],
 }
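
The 'path::ratio' suffix on each dataset entry appears to encode a per-dataset sampling weight (helpful-base at 1.0, harmless-base at 0.5). A minimal sketch of that convention; split_dataset_spec is a hypothetical helper written for illustration, not a function from this repo:

def split_dataset_spec(spec):
    # Illustrative only: 'Anthropic/hh-rlhf/harmless-base::0.5'
    #   -> ('Anthropic/hh-rlhf/harmless-base', 0.5)
    path, sep, ratio = spec.partition('::')
    return path, float(ratio) if sep else 1.0

assert split_dataset_spec('Anthropic/hh-rlhf/harmless-base::0.5') == \
    ('Anthropic/hh-rlhf/harmless-base', 0.5)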