update datasets #84

Merged · 7 commits · Aug 7, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -181,7 +181,7 @@ We can also tweak our hyperparameters:
```bash
python train_qlora.py \
--model_name_or_path ~/checkpoints/baichuan7b \
-    --dataset_name oasst1 \
+    --dataset_cfg ./data/alpaca_zh_pcyn.yaml \
--output_dir ./work_dir/oasst1-baichuan-7b \
--num_train_epochs 4 \
--per_device_train_batch_size 4 \
4 changes: 2 additions & 2 deletions README_zh.md
@@ -189,10 +189,10 @@ python train_qlora.py --learning_rate 0.0001 --model_name_or_path <path_or_name

We can also tweak our hyperparameters:

-```python
+```bash
python train_qlora.py \
--model_name_or_path ~/checkpoints/baichuan7b \
-    --dataset_name oasst1 \
+    --dataset_cfg ./data/alpaca_zh_pcyn.yaml \
--data_dir ~/prompt_datasets \
--load_from_local \
--output_dir ./work_dir/oasst1-baichuan-7b \
36 changes: 11 additions & 25 deletions chatllms/configs/data_args.py
@@ -12,7 +12,6 @@ class DatasetAttr(object):
     hf_hub_url: Optional[str] = None
     local_path: Optional[str] = None
     dataset_format: Optional[str] = None
-    dataset_sha1: Optional[str] = None
     load_from_local: bool = False
     multi_turn: Optional[bool] = False
@@ -34,18 +33,11 @@ def __post_init__(self):

 @dataclass
 class DataArguments:
-    # The fine-tuning dataset is alpaca
-    dataset_name: Optional[str] = field(
-        default='alpaca',
-        metadata={
-            'help': 'Which dataset to finetune on. See datamodule for options.'
-        })
-    # Local path of the dataset; if load_from_local is True, the dataset is loaded from this local path
-    dataset_dir: str = field(
-        default=None,
+    dataset_cfg: Optional[str] = field(
+        default='./data/alpaca_zh.yaml',
         metadata={
             'help':
-            'where is dataset in local dir. See datamodule for options.'
+            'Path to dataset infos, please refer to `./data/README.md` to see how to prepare your datasets for training.'
         })
     instruction_template: str = field(
         default='default',
@@ -82,19 +74,13 @@ class DataArguments:
     )

     def init_for_training(self):  # support mixing multiple datasets
-        dataset_names = [ds.strip() for ds in self.dataset_name.split(',')]
-        this_dir = os.path.dirname(os.path.abspath(__file__))
-        datasets_info_path = os.path.join(this_dir, '../..', 'data',
-                                          'dataset_info.yaml')
-        with open(datasets_info_path, 'r') as f:
-            datasets_info = yaml.safe_load(f)
-
-        self.datasets_list: List[DatasetAttr] = []
-        for i, name in enumerate(dataset_names):
-            if name not in datasets_info:
-                raise ValueError('Undefined dataset {} in {}'.format(
-                    name, datasets_info_path))
-
+        assert self.dataset_cfg is not None and os.path.exists(
+            self.dataset_cfg
+        ), f'{self.dataset_cfg} does not exist!, please check the path.'
+        datasets_info = yaml.safe_load(open(self.dataset_cfg, 'r'))
+        self.dataset_names = list(datasets_info.keys())
+        self.dataset_attr_list: List[DatasetAttr] = []
+        for i, name in enumerate(self.dataset_names):
             dataset_attr = DatasetAttr()
             dataset_attr.dataset_name = name
             dataset_attr.dataset_format = datasets_info[name].get(
@@ -126,4 +112,4 @@ def init_for_training(self):  # support mixing multiple datasets
                 dataset_attr.history_column = datasets_info[name][
                     'columns'].get('history', None)

-            self.datasets_list.append(dataset_attr)
+            self.dataset_attr_list.append(dataset_attr)
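
For reference, here is a minimal standalone sketch of the new loading flow introduced above: it reads a `dataset_cfg` YAML and materializes one `DatasetAttr` per top-level entry. It assumes only PyYAML and the simplified `DatasetAttr` fields shown in the diff; `load_dataset_attrs` is a hypothetical helper written for illustration, not part of the repository.

```python
# Hypothetical sketch of the dataset_cfg loading flow (illustration only).
import os
from dataclasses import dataclass
from typing import List, Optional

import yaml


@dataclass
class DatasetAttr:
    """Simplified stand-in for chatllms.configs.data_args.DatasetAttr."""
    dataset_name: Optional[str] = None
    hf_hub_url: Optional[str] = None
    local_path: Optional[str] = None
    dataset_format: Optional[str] = None
    multi_turn: bool = False


def load_dataset_attrs(dataset_cfg: str) -> List[DatasetAttr]:
    """Parse a dataset_cfg YAML into one DatasetAttr per top-level entry."""
    assert os.path.exists(dataset_cfg), f'{dataset_cfg} does not exist!'
    with open(dataset_cfg, 'r') as f:
        datasets_info = yaml.safe_load(f)
    return [
        DatasetAttr(dataset_name=name,
                    hf_hub_url=info.get('hf_hub_url'),
                    local_path=info.get('local_path'),
                    dataset_format=info.get('dataset_format'),
                    multi_turn=info.get('multi_turn', False))
        for name, info in datasets_info.items()
    ]


if __name__ == '__main__':
    for attr in load_dataset_attrs('./data/alpaca_zh.yaml'):
        print(attr)
```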
2 changes: 1 addition & 1 deletion chatllms/configs/lora_args.py
@@ -11,7 +11,7 @@ class LoraArguments:
     lora_dropout: float = field(default=0.0,
                                 metadata={'help': 'Lora dropout.'})
     # Maximum GPU memory available per GPU, in MB. The default corresponds to the 80GB high-end A100.
-    max_memory_MB: int = field(default=8000,
+    max_memory_MB: int = field(default=80000,
                                metadata={'help': 'Free memory per gpu.'})
     lora_weight_path: str = ''
     bias: str = 'none'
6 changes: 3 additions & 3 deletions chatllms/data/data_utils.py
@@ -460,16 +460,16 @@ def make_data_module(args):
     """
     train_datasets: List[Dataset] = []
     eval_datasets: List[Dataset] = []
-    dataset_name_list = args.dataset_name.split(',')
+    dataset_name_list = args.dataset_names
    print(f'Loading datasets: {dataset_name_list}')
     mutliturn_lst = [
-        dataset_attr.multi_turn for dataset_attr in args.datasets_list
+        dataset_attr.multi_turn for dataset_attr in args.dataset_attr_list
     ]
     assert mutliturn_lst.count(mutliturn_lst[0]) == len(
         mutliturn_lst
     ), 'All datasets should be multi-turn or single-turn. As follwing we will concat all datasets, so they should be in the same format.'

-    for dataset_attr in args.datasets_list:
+    for dataset_attr in args.dataset_attr_list:
         print('=' * 80)
         print('DatasetAttr: {}'.format(dataset_attr))
65 changes: 33 additions & 32 deletions data/README.md
@@ -47,7 +47,7 @@ We provide the following datasets for the experiments in this framework.

## Dataset formation

-The `dataset_info.yaml` file contains the information of the datasets, main including the following fields.
+The `dataset_info.yaml` file lists all the datasets that can be used in the experiments. Each dataset entry mainly includes the following fields.

```yaml
dataset_name:
@@ -77,37 +47,6 @@ alpaca:
multi_turn: False
```

-### How to use in training scripts
-
-After specifying the dataset information, you can run the following command to train the model. Just specify the `dataset_name` as one of the dataset names in `dataset_info.yaml`. If you want to use more than one dataset, please specify `dataset_name` as a comma-separated list, e.g., `--dataset_name 'alpaca,dolly'`.
-
-```shell
-python train.py \
-    --model_name_or_path facebook/opt-125m \
-    --dataset_name alpaca \
-    --output_dir work_dir/full-finetune \
-    --num_train_epochs 3 \
-    --per_device_train_batch_size 4 \
-    --per_device_eval_batch_size 4 \
-    --gradient_accumulation_steps 8 \
-    --evaluation_strategy "steps" \
-    --save_strategy "steps" \
-    --eval_steps 1000 \
-    --save_steps 1000 \
-    --save_total_limit 5 \
-    --logging_steps 1 \
-    --learning_rate 2e-5 \
-    --weight_decay 0. \
-    --warmup_ratio 0.03 \
-    --optim "adamw_torch" \
-    --lr_scheduler_type "cosine" \
-    --gradient_checkpointing True \
-    --model_max_length 128 \
-    --do_train \
-    --do_eval
-```


## Custom datasets

If you are using a custom dataset, please provide your dataset definition in `dataset_info.yaml`.
@@ -257,3 +226,35 @@ def main():
if __name__ == '__main__':
    main()
```

+### How to use in training scripts
+
+In the `data/` directory, we provide several dataset info configs used in the experiments. The following script shows how to use the `alpaca_zh.yaml` dataset info config.

+```shell
+python train.py \
+    --model_name_or_path facebook/opt-125m \
+    --dataset_cfg alpaca_zh.yaml \
+    --output_dir work_dir/full-finetune \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps 8 \
+    --evaluation_strategy "steps" \
+    --save_strategy "steps" \
+    --eval_steps 1000 \
+    --save_steps 1000 \
+    --save_total_limit 5 \
+    --logging_steps 1 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --optim "adamw_torch" \
+    --lr_scheduler_type "cosine" \
+    --gradient_checkpointing True \
+    --model_max_length 128 \
+    --do_train \
+    --do_eval
+```

+You can use `alpaca_zh.yaml` directly, or create a custom dataset config and set the `dataset_cfg` argument to `your_dataset_info.yaml`.
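
For illustration, the sketch below creates such a custom config programmatically and writes it to `your_dataset_info.yaml`; the `my_dataset` entry and its local path are hypothetical placeholders, and only the fields documented above (`hf_hub_url`, `local_path`, `dataset_format`, `multi_turn`) are assumed.

```python
# Hypothetical sketch: build a custom dataset config and dump it to YAML.
import yaml

dataset_info = {
    'my_dataset': {  # placeholder dataset name
        'hf_hub_url': '',  # leave empty when loading from a local file
        'local_path': '/path/to/my_dataset_alpaca.json',  # placeholder path
        'dataset_format': 'alpaca',
        'multi_turn': False,
    }
}

with open('your_dataset_info.yaml', 'w') as f:
    yaml.safe_dump(dataset_info, f, sort_keys=False)

# Then point training at it:
#   python train.py --dataset_cfg your_dataset_info.yaml ...
```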
42 changes: 42 additions & 0 deletions data/alpaca_zh.yaml
@@ -0,0 +1,42 @@
# This file contains the information of the datasets used in the experiments.
coig:
  hf_hub_url: BAAI/COIG
  local_path: /home/robin/prompt_data/COIG/train_alpaca.json
  dataset_format: alpaca
  multi_turn: False

cvalues_comparison_train:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/CValues-Comparison/train_alpaca.json
  dataset_format: alpaca
  multi_turn: False

cvalues_comparison_test:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/CValues-Comparison/test_alpaca.json
  dataset_format: alpaca
  multi_turn: False

olcc:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/olcc/olcc_alpaca.json
  dataset_format: alpaca
  multi_turn: False

100PoisonMpts:
  hf_hub_url: 'damo/100PoisonMpts'
  local_path: /home/robin/prompt_data/100PoisonMpts/train_alpaca.json
  dataset_format: alpaca
  multi_turn: False

safety_prompt_part1:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/Safety-Prompts/attack_scenarios_alpaca.json
  dataset_format: alpaca
  multi_turn: False

safety_prompt_part2:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json
  dataset_format: alpaca
  multi_turn: False
42 changes: 42 additions & 0 deletions data/alpaca_zh_pcyn.yaml
@@ -0,0 +1,42 @@
# This file contains the information of the datasets used in the experiments.
coig:
  hf_hub_url: BAAI/COIG
  local_path: /userhome/jianzhnie/prompt_data/COIG/train_alpaca.json
  dataset_format: alpaca
  multi_turn: False

cvalues_comparison_train:
  hf_hub_url: ''
  local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/train_alpaca.json
  dataset_format: alpaca
  multi_turn: False

cvalues_comparison_test:
  hf_hub_url: ''
  local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/test_alpaca.json
  dataset_format: alpaca
  multi_turn: False

olcc:
  hf_hub_url: ''
  local_path: /userhome/jianzhnie/prompt_data/olcc/olcc_alpaca.json
  dataset_format: alpaca
  multi_turn: False

100PoisonMpts:
  hf_hub_url: ''
  local_path: /userhome/jianzhnie/prompt_data/100PoisonMpts/train_alpaca.json
  dataset_format: alpaca
  multi_turn: False

safety_prompt_part1:
  hf_hub_url: ''
  local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/attack_scenarios_alpaca.json
  dataset_format: alpaca
  multi_turn: False

safety_prompt_part2:
  hf_hub_url: ''
  local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json
  dataset_format: alpaca
  multi_turn: False
40 changes: 40 additions & 0 deletions data/belle_group.yaml
@@ -0,0 +1,40 @@
belle_0.5m:
  hf_hub_url: BelleGroup/train_0.5M_CN
  local_path: ''
  dataset_format: alpaca
  multi_turn: False

belle_1m:
  hf_hub_url: BelleGroup/train_1M_CN
  local_path: ''
  dataset_format: alpaca
  multi_turn: False

belle_2m:
  hf_hub_url: BelleGroup/train_2M_CN
  local_path: ''
  dataset_format: alpaca
  multi_turn: False

belle_dialog:
  hf_hub_url: BelleGroup/generated_chat_0.4M
  local_path: ''
  dataset_format: belle_dialog
  multi_turn: False

belle_math:
  hf_hub_url: BelleGroup/school_math_0.25M
  local_path: ''
  dataset_format: alpaca
  multi_turn: False

belle_multiturn:
  hf_hub_url: BelleGroup/multi_turn_0.5M
  local_path: ''
  dataset_format: belle_multiturn
  multi_turn: True
  columns:
    prompt: instruction
    query: ''
    response: output
    history: history
2 changes: 1 addition & 1 deletion data/dataset_info.yaml
@@ -66,7 +66,7 @@ olcc:
   multi_turn: False

 100PoisonMpts:
-  hf_hub_url: ''
+  hf_hub_url: 'damo/100PoisonMpts'
   local_path: /home/robin/prompt_data/100PoisonMpts/train.jsonl
   dataset_format: 100PoisonMpts
   multi_turn: False
42 changes: 42 additions & 0 deletions data/vicuna_zh.yaml
@@ -0,0 +1,42 @@
# This file contains the information of the datasets used in the experiments.
coig:
  hf_hub_url: BAAI/COIG
  local_path: /home/robin/prompt_data/COIG/train_vicuna.json
  dataset_format: sharegpt
  multi_turn: True

cvalues_comparison_train:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/CValues-Comparison/train_vicuna.json
  dataset_format: sharegpt
  multi_turn: True

cvalues_comparison_test:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/CValues-Comparison/test_vicuna.json
  dataset_format: sharegpt
  multi_turn: True

olcc:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/olcc/olcc_vicuna.json
  dataset_format: sharegpt
  multi_turn: True

100PoisonMpts:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/100PoisonMpts/train_vicuna.json
  dataset_format: sharegpt
  multi_turn: True

safety_prompt_part1:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/Safety-Prompts/attack_scenarios_vicuna.json
  dataset_format: sharegpt
  multi_turn: True

safety_prompt_part2:
  hf_hub_url: ''
  local_path: /home/robin/prompt_data/Safety-Prompts/safety_scenarios_vicuna.json
  dataset_format: sharegpt
  multi_turn: True