Merge pull request #76 from jianzhnie/dev
update datasets
jianzhnie committed Jul 27, 2023
2 parents 5402df2 + 7f83fd8 commit c46cd52
Showing 5 changed files with 238 additions and 48 deletions.
6 changes: 5 additions & 1 deletion chatllms/configs/data_args.py
@@ -11,12 +11,14 @@ class DatasetAttr(object):
dataset_name: Optional[str] = None
hf_hub_url: Optional[str] = None
local_path: Optional[str] = None
dataset_format: Optional[str] = None
dataset_sha1: Optional[str] = None
load_from_local: bool = False
multi_turn: Optional[bool] = False

def __repr__(self) -> str:
return self.dataset_name
rep = f'dataset_name: {self.dataset_name}, hf_hub_url: {self.hf_hub_url}, local_path: {self.local_path}, dataset_format: {self.dataset_format}, load_from_local: {self.load_from_local}, multi_turn: {self.multi_turn}'
return rep

def __post_init__(self):
self.prompt_column = 'instruction'
@@ -90,6 +92,8 @@ def init_for_training(self): # support mixing multiple datasets

dataset_attr = DatasetAttr()
dataset_attr.dataset_name = name
dataset_attr.dataset_format = datasets_info[name].get(
'dataset_format', None)
dataset_attr.hf_hub_url = datasets_info[name].get(
'hf_hub_url', None)
dataset_attr.local_path = datasets_info[name].get(
2 changes: 1 addition & 1 deletion chatllms/configs/train_args.py
@@ -71,7 +71,7 @@ class TrainingArguments(TrainingArguments):
'Group sequences into batches with same length. Saves memory and speeds up training considerably.'
})
model_max_length: int = field(
default=2048,
default=1024,
metadata={
'help':
'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
49 changes: 30 additions & 19 deletions chatllms/data/data_utils.py
@@ -242,12 +242,16 @@ def load_data(
"""
if not os.path.exists(dataset_path):
# Download dataset from HuggingFace Datasets
print(
f'Loading dataset from HuggingFace, please refer to https://huggingface.co/datasets/{dataset_path}'
)
dataset = load_dataset(dataset_path,
cache_dir='~/.cache/huggingface/datasets')
return dataset
else:
# Load dataset from local file
try:
print(f'Loading dataset from local path: {dataset_path}')
dataset = local_dataset(dataset_path, eval_dataset_size)
return dataset
except:
@@ -257,6 +261,7 @@ def formate_instruction_dataset(
def formate_instruction_dataset(
dataset: Dataset,
dataset_name: str,
dataset_format: str,
instruction_template: str = 'default') -> Optional[Dict[str, Dataset]]:
"""
Formats a given dataset based on its name and format.
@@ -270,6 +275,7 @@ Args:
Args:
dataset: A dataset object to be formatted.
dataset_name: A string representing the name of the dataset to be formatted.
dataset_format: A string representing the name of the dataset format to be used.
instruction_template: A string representing the name of the prompt template to be used.
Returns:
@@ -326,36 +332,33 @@ def _remove_unused_columns(dataset):
])
return dataset

print('formate the dataset to the format we need.')
if dataset_name == 'dolly-15k':
# Format dataset
print(f'The {dataset_name} dataset uses the {dataset_format} dataset format.')
if dataset_format == 'alpaca':
print('By default, we support the Alpaca dataset format.')
elif dataset_format == 'dolly':
dataset = _format_dolly15k(dataset)
elif dataset_name == 'chip2':
elif dataset_format == 'chip2':
dataset = _format_chip2(dataset)
elif dataset_name == 'self-instruct':
elif dataset_format == 'self-instruct':
dataset = _format_self_instruct(dataset)
elif dataset_name == 'hh-rlhf':
elif dataset_format == 'hh-rlhf':
dataset = _format_hh_rlhf(dataset)
elif dataset_name == 'oasst1':
elif dataset_format == 'oasst1':
dataset = _format_oasst1(dataset)
elif dataset_name == '100PoisonMpts':
elif dataset_format == '100PoisonMpts':
dataset = _format_100Poison(dataset)
else:
print(
f'For dataset {dataset_name} with alpaca dataset formation, we do not need additional processing'
raise NotImplementedError(
f'Unsupported dataset format: {dataset_format}, please add the format function in data_utils.py'
)
pass

# encode_instruction_example
print(
f'Encoding the instruction example refer to : {instruction_template}')
print(f'Applying instruction template: {instruction_template}')
if instruction_template == 'alpaca':
print('Using alpaca prompt template: ', {instruction_template})
dataset = dataset.map(extract_alpaca_prompt_dataset)
elif instruction_template == 'random':
print('Using random prompt template: ', {instruction_template})
dataset = dataset.map(extract_random_prompt_dataset)
else:
print('Using default prompt template: ', {instruction_template})
dataset = dataset.map(extract_default_prompt_dataset)

# Remove unused columns.
@@ -402,13 +405,16 @@ def split_train_eval(
else:
# Split train dataset in train and validation according to `eval_dataset_size`
print(
'Splitting train dataset in train and validation according to `eval_dataset_size`'
f'Splitting the dataset into train and validation according to `eval_dataset_size`: {eval_dataset_size}'
)
dataset = dataset['train'].train_test_split(
test_size=eval_dataset_size, shuffle=True, seed=42)
eval_dataset = dataset['test']

# Reduce evaluation dataset size (if specified)
print(
f'You have set max_eval_samples to {max_eval_samples}, sampling will be applied ...'
)
if max_eval_samples is not None and len(
eval_dataset) > max_eval_samples:
eval_dataset = eval_dataset.select(np.arange(max_eval_samples))
@@ -418,6 +424,9 @@
train_dataset = dataset['train']

# Reduce training dataset size (if specified)
print(
f'You have set max_train_samples to {max_train_samples}, sampling will be applied ...'
)
if max_train_samples is not None and len(
train_dataset) > max_train_samples:
train_dataset = train_dataset.select(np.arange(max_train_samples))
@@ -461,12 +470,13 @@ def make_data_module(args):
), 'All datasets should be multi-turn or single-turn. Since we will concatenate all datasets below, they should be in the same format.'

for dataset_attr in args.datasets_list:
print('Loading dataset {}...'.format(dataset_attr))
print('=' * 80)
print('DatasetAttr: {}...'.format(dataset_attr))

if dataset_attr.load_from_local:
dataset_path = dataset_attr.local_path
elif dataset_attr.hf_hub_url:
dataset_path = dataset_attr.dataset_name
dataset_path = dataset_attr.hf_hub_url

dataset = load_data(dataset_path,
eval_dataset_size=args.eval_dataset_size)
@@ -475,6 +485,7 @@
dataset = formate_instruction_dataset(
dataset,
dataset_name=dataset_attr.dataset_name,
dataset_format=dataset_attr.dataset_format,
instruction_template=args.instruction_template,
)

183 changes: 170 additions & 13 deletions data/README.md
@@ -1,7 +1,7 @@

## How to use the data
# How to use the data

### Datasets Supported by the Framework
## Datasets Supported by the Framework

We provide the following datasets for the experiments in this framework.

@@ -25,26 +25,183 @@ We provide the following datasets for the experiments in this framework.
- [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco)
- [Evol-Instruct](https://huggingface.co/datasets/victor123/evol_instruct_70k)

### Dataset formation
## Dataset formation

The `dataset_info.yaml` file contains the information of the datasets. By defaullt, the framework will load the datasets from the HuggingFace hub. If you want to use the datasets from local files, please specify the `local_path` in the `dataset_info.yaml` file. For example, if you want to use the Alpaca dataset from local files, please specify the following in `dataset_info.yaml`.
The `dataset_info.yaml` file contains the information of the datasets, mainly including the following fields.

```yaml
dataset_name:
  hf_hub_url: # the name of the dataset repository on the HuggingFace hub (if specified, local_path can be left empty)
  local_path: # the path of the dataset file in this directory (required if hf_hub_url is not specified)
  dataset_format: # the format of the dataset (required), e.g., alpaca, dolly, etc.
  multi_turn: # whether the dataset is multi-turn (default: False)
```

For example, the following is the dataset information of the Stanford Alpaca dataset.

```yaml
alpaca:
hf_hub_url: tatsu-lab/alpaca
local_path: tatsu-lab/alpaca/alpaca.json
local_path:
dataset_format: alpaca
multi_turn: False
```
During training, the framework will load the dataset from the HuggingFace hub. If you want to load the dataset from a local file instead, specify the `local_path` field.

### Custom datasets
```yaml
alpaca:
hf_hub_url: tatsu-lab/alpaca
local_path: path/to/alpaca.json
dataset_format: alpaca
multi_turn: False
```
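
For illustration, here is a minimal sketch of how such an entry could be resolved at load time. It mirrors the local-vs-hub check in `load_data` (`chatllms/data/data_utils.py`) but is not the framework's actual loader; the `resolve_and_load` helper and the use of PyYAML are assumptions made for this example.

```python
import os

import yaml  # assumption: PyYAML is available
from datasets import load_dataset


def resolve_and_load(entry: dict):
    """Load from a local file when one is configured, otherwise from the HF hub."""
    local_path = entry.get('local_path')
    if local_path and os.path.exists(local_path):
        # Local JSON file, e.g. path/to/alpaca.json
        return load_dataset('json', data_files=local_path)
    # Hub repository, e.g. tatsu-lab/alpaca
    return load_dataset(entry['hf_hub_url'],
                        cache_dir='~/.cache/huggingface/datasets')


with open('dataset_info.yaml') as f:
    info = yaml.safe_load(f)
dataset = resolve_and_load(info['alpaca'])
```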

If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.yaml`.
## Custom datasets

```yaml
dataset_name:
hf_hub_url: # "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
local_path: # "the name of the dataset file in the this directory. (required if above are not specified)",
multi_turn: # "whether the dataset is multi-turn. (default: False)"
If you are using a custom dataset, please provide your dataset definition in `dataset_info.yaml`.

### hf_hub_url and local_path

By default, the framework will load the datasets from the HuggingFace hub. If you want to use a local dataset file, please specify the `local_path` field.

### dataset_format

The `dataset_format` field specifies the format of the dataset and determines how it will be processed. Currently, we support the following dataset formats (a sample Alpaca-format record is shown after this list).

- `alpaca`: Alpaca dataset
- `dolly`: Dolly dataset
- `gpt4`: GPT-4 generated dataset
- `alpaca_cot`: Alpaca CoT dataset
- `oasst1`: OpenAssistant/oasst1 dataset
- `sharegpt`: Multi-turn ShareGPT dataset
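
For reference, an Alpaca-format record is a JSON object with `instruction`, `input`, and `output` fields; the concrete values below are made up for illustration:

```json
{
  "instruction": "Summarize the following paragraph.",
  "input": "Large language models can follow natural-language instructions ...",
  "output": "The paragraph explains that large language models follow instructions."
}
```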

If your dataset is not in one of the above formats, there are two ways to use it.

- The first way is to implement a formatting function in [data_utils](./chatllms/data/data_utils.py).

For example, the following is the `_format_dolly15k` function for the Dolly dataset.

```python
def _format_dolly15k(dataset: Dataset) -> Dataset:
"""Format Dolly-15k dataset."""
dataset = dataset.rename_column('context', 'input')
dataset = dataset.rename_column('response', 'output')
return dataset
```
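
If you take the first route, the sketch below shows what a new helper could look like and where it would hook into `formate_instruction_dataset`. The `my_format` name and the `question`/`answer` columns are hypothetical:

```python
from datasets import Dataset


def _format_mydata(dataset: Dataset) -> Dataset:
    """Hypothetical example: map question/answer columns onto instruction/output."""
    dataset = dataset.rename_column('question', 'instruction')
    dataset = dataset.rename_column('answer', 'output')
    return dataset

# Then add a matching branch in formate_instruction_dataset():
# elif dataset_format == 'my_format':
#     dataset = _format_mydata(dataset)
```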

- The second way is to convert your dataset to one of the above formats.

For example, the following code converts the [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) dataset to the Alpaca format.

```python
import json
def convert_dolly_alpaca(in_file, out_file):
    with open(in_file, 'r') as file:
        contents = json.load(file)

    new_content = []
    for i, content in enumerate(contents):
        new_content.append({
            'instruction': content['instruction'],
            'input': content['context'],
            'output': content['response'],
        })

    print(f'#out: {len(new_content)}')
    with open(out_file, 'w') as file:
        json.dump(new_content, file, indent=2, ensure_ascii=False)
```
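
A hypothetical invocation, with placeholder file names and assuming the input file holds a single JSON array (which is what `json.load` expects):

```python
convert_dolly_alpaca('dolly_raw.json', 'dolly_alpaca.json')
```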

where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair.
### multi_turn

If your dataset is multi-turn, please set `multi_turn: True` in `dataset_info.yaml`. The framework will automatically process the multi-turn dataset.
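
For example, a hypothetical `dataset_info.yaml` entry for a local multi-turn dataset (the name and path are placeholders) could look like this:

```yaml
vicuna_demo:
  hf_hub_url:
  local_path: path/to/vicuna_demo.json
  dataset_format: sharegpt
  multi_turn: True
```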

The following is an example of the multi-turn dataset format.

```json
[
{
"id": "identity_0",
"conversations": [
{
"from": "human",
"value": "Who are you?"
},
{
"from": "gpt",
"value": "I am Vicuna, a language model trained by researchers from Large Model Systems Organization (LMSYS)."
},
{
"from": "human",
"value": "What can you do?"
},
{
"from": "gpt",
"value": "I can chat with you."
}
]
},
{
"id": "identity_1",
"conversations": [
{
"from": "human",
"value": "Who are you?"
},
{
"from": "gpt",
"value": "My name is Vicuna, and I'm a language model developed by Large Model Systems Organization (LMSYS)."
}
]
}
]
```

For now, we only support multi-turn datasets in the above format. If your dataset is not in this format, please convert it. We also provide the following code to convert the Dolly dataset to the above format; you can find it in [convert_alpaca](./chatllms/data/utils/convert_alpaca.py).

```python
import argparse
import json
from typing import Any, Dict, List

from datasets import load_dataset

def convert_dolly_vicuna(raw_data: List[Dict[str, Any]]):
    collect_data = []
    for i, content in enumerate(raw_data):
        if len(content['context'].strip()) > 1:
            q = content['instruction'] + '\nInput:\n' + content['context']
            a = content['response']
        else:
            q, a = content['instruction'], content['response']

        collect_data.append({
            'id': f'alpaca_{i}',
            'conversations': [
                {
                    'from': 'human',
                    'value': q
                },
                {
                    'from': 'gpt',
                    'value': a
                },
            ],
        })

    print(f'Original: {len(raw_data)}, Converted: {len(collect_data)}')
    return collect_data


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--in-file', type=str)
    parser.add_argument('--out-file', type=str)
    args = parser.parse_args()

    raw_data = load_dataset('json', data_files=args.in_file)['train']
    new_data = convert_dolly_vicuna(raw_data)
    with open(args.out_file, 'w') as f:
        json.dump(new_data, f, indent=2, ensure_ascii=False)


if __name__ == '__main__':
    main()
```
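
Assuming the script is saved as `chatllms/data/utils/convert_alpaca.py` as linked above, it can be run with, for example, `python chatllms/data/utils/convert_alpaca.py --in-file dolly_raw.json --out-file dolly_vicuna.json`, where both file names are placeholders.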