
Supervised Fine-tuning for HuggingFace pretrained weights #318

Merged: 5 commits, Jan 26, 2024

Conversation

@inkcherry commented Dec 18, 2023

1. We noticed that many users employ Megatron to fine-tune HuggingFace pretrained weights, aiming for better performance or convergence on large-scale models. Megatron-LM already supports weight conversion (https://github.com/NVIDIA/Megatron-LM/blob/main/docs/llama2.md), but it is not available for non-CUDA devices. Here we add weight conversion from HF llama to Megatron-DeepSpeed; a rough sketch of the core layout change is shown below.
(An earlier PR, #246, was also related to weight conversion, but I failed to use it; it appears to target the Megatron-LM format, and unfortunately we were unable to contact the author.)
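For reference, the trickiest part of such a conversion is usually fusing HF llama's separate q_proj/k_proj/v_proj weights into Megatron's single query_key_value matrix, which interleaves q, k, and v per attention head. The snippet below is only a minimal sketch of that layout change under the assumption of plain multi-head attention (no grouped-query attention); the function and variable names are illustrative, not the PR's actual code, and a real converter also has to account for the differing rotary-embedding layouts.

    import torch

    def fuse_qkv(q_proj: torch.Tensor, k_proj: torch.Tensor, v_proj: torch.Tensor,
                 num_heads: int, hidden_size: int) -> torch.Tensor:
        """Sketch: pack separate HF q/k/v weights into one Megatron-style QKV weight.

        Assumes each input has shape [hidden_size, hidden_size], i.e. kv heads == heads;
        grouped-query attention is not handled here.
        """
        head_dim = hidden_size // num_heads
        q = q_proj.view(num_heads, head_dim, hidden_size)
        k = k_proj.view(num_heads, head_dim, hidden_size)
        v = v_proj.view(num_heads, head_dim, hidden_size)
        # Megatron groups q, k, v rows per head: [q_h | k_h | v_h] for each head h.
        qkv = torch.cat([q, k, v], dim=1)   # [num_heads, 3 * head_dim, hidden_size]
        return qkv.reshape(3 * hidden_size, hidden_size)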

2. Additionally, we add a fine-tuning script. It covers the SFT process, uses an HF tokenizer, and takes a prompt dataset (see https://github.com/tatsu-lab/stanford_alpaca) along with a dataloader.
We use a repeating dataloader to avoid the StopIteration errors that occur when the dataset is too small; a sketch of the idea follows.
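As a rough illustration of the repeating-dataloader idea (not the PR's actual implementation), a wrapper can simply restart the underlying iterator whenever it is exhausted:

    from torch.utils.data import DataLoader

    class RepeatingLoader:
        """Wrap a DataLoader so that iteration never raises StopIteration."""

        def __init__(self, loader: DataLoader):
            self.loader = loader
            self.iterator = iter(loader)

        def __iter__(self):
            return self

        def __next__(self):
            try:
                return next(self.iterator)
            except StopIteration:
                # Dataset exhausted: start over so a fixed-step training loop
                # (e.g. Megatron's --train-iters) keeps receiving batches.
                self.iterator = iter(self.loader)
                return next(self.iterator)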

Through steps 1 and 2, it is now feasible to run the Alpaca fine-tuning task (https://github.com/tatsu-lab/stanford_alpaca) with Megatron-DeepSpeed.

Furthermore, we've made some other changes:
1. Fix an RMSNorm fallback path issue.
2. Handle invalid HF tokenizer arguments.
3. Skip the 'cat' in the RoPE operation when the concatenated slice has size 0 along the cat dim (sketched below).
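For change 3, the guard amounts to skipping torch.cat when the pass-through part of the rotary embedding is empty, since concatenating a zero-sized tensor is unnecessary (and can be problematic on some non-CUDA backends). A hypothetical sketch with illustrative names, not the PR's exact code:

    import torch

    def merge_rotary_parts(t_rot: torch.Tensor, t_pass: torch.Tensor) -> torch.Tensor:
        # If rotary embeddings cover the whole head dimension, t_pass has size 0
        # along the last dim; return the rotated part directly instead of cat-ing.
        if t_pass.shape[-1] == 0:
            return t_rot
        return torch.cat((t_rot, t_pass), dim=-1)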

@LLLLLgq commented Jan 2, 2024

A condition should be added when filling new_w for padded tokens in the _embedding_refactor function, to cover the case where there are no padded tokens; otherwise new_w[-self.more_padded:] with self.more_padded == 0 would overwrite every row:

        new_w = torch.zeros((per_partition_vocab_size, hf_w.shape[1]), dtype=hf_w.dtype)
        new_w[:real_partition_vocab_size, :] = hf_w[start_index:end_index, :]
        # Only the last TP rank holds padded vocab rows; fill them with the mean
        # of the real token embeddings, and skip this entirely when nothing is padded.
        if self.tp_rank == self.tp_size - 1 and self.more_padded > 0:
            new_w[-self.more_padded:] = hf_w[:self.token_vocab].mean(dim=0, keepdim=True)

@inkcherry (Author) commented Jan 4, 2024

Thank you for pointing that out, @LLLLLgq! Would it work for your case after this change?
I've added a script example that you could refer to as well.

@inkcherry (Author) commented Jan 23, 2024

We have also plotted the loss curves for transformers ZeRO-2 training and Megatron-DeepSpeed 3D parallel training. The loss is printed at every step, with approximately 400 steps per epoch. The initial loss for both runs is around 1.6 (Megatron-DS pretraining loss typically starts above 10.0, so this demonstrates that the weight conversion works). Throughout training, the two curves stay consistent at the turning points, and the final loss values also agree. cc @kefeiyao

  • [figure] transformers ZeRO-2 loss curve
  • [figure] Megatron-DeepSpeed 3D training loss curve

@zdaiot commented Jan 23, 2024

@inkcherry Hello, do you have plans to support llama2?

@conglongli self-assigned this Jan 26, 2024
@conglongli left a comment

@inkcherry Thank you for the contribution. I have tested it and it LGTM. The conflict in megatron/model/gpt_model.py is because I merged a newer PR (#341) that fixed the same issue. If you could help resolve it, I will then merge this PR. Thanks.

@conglongli commented

@inkcherry Never mind, it's a simple conflict and I just resolved it. Merging now.

@conglongli merged commit 11f2d93 into microsoft:main on Jan 26, 2024
1 check passed
@inkcherry (Author) commented

@conglongli Thanks for the help!
@zdaiot Probably not in the near future, but I think llama2 13B should work, because its number of kv heads matches its number of attention heads (I haven't verified this). Llama2 70B would need extra attention here; see the check below.
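For context, whether kv heads and attention heads match can be read from the HF config: llama2 7B/13B use plain multi-head attention, while 70B uses grouped-query attention, which the converter does not handle. A small check, assuming the transformers AutoConfig API and hypothetical local checkpoint paths:

    from transformers import AutoConfig

    for path in ("./llama2-13b-hf", "./llama2-70b-hf"):  # hypothetical local paths
        cfg = AutoConfig.from_pretrained(path)
        # 13B: num_key_value_heads == num_attention_heads (MHA, convertible as-is).
        # 70B: num_key_value_heads < num_attention_heads (GQA, needs extra handling).
        print(path, cfg.num_attention_heads, getattr(cfg, "num_key_value_heads", None))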

@cangshuli commented Jan 30, 2024

Can you support GPTModel for llama2-7b without pipeline parallelism, i.e., model conversion with pp=1, tp=1, dp=8? Currently only the pipeline path is implemented (see the sketch after the excerpt below):

In hf2megads_weight_converter.py:

    if args.deepspeed and not args.no_pipeline_parallel:
        model = GPTModelPipe(config, num_tokentypes=0, parallel_output=True)
    else:
        raise NotImplementedError("Not implemented")
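A hedged sketch of what a non-pipeline branch might look like, assuming GPTModel in this repository takes the same constructor arguments as upstream Megatron-LM; this is not the PR's code and the exact signature may differ:

    if args.deepspeed and not args.no_pipeline_parallel:
        model = GPTModelPipe(config, num_tokentypes=0, parallel_output=True)
    else:
        # Hypothetical: with pp=1 a single stage holds both the embedding and the
        # output head, so pre_process and post_process are both True.
        model = GPTModel(config,
                         num_tokentypes=0,
                         parallel_output=True,
                         pre_process=True,
                         post_process=True)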
