Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update datasets #77

Merged
merged 1 commit into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion chatllms/configs/data_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ class DatasetAttr(object):
multi_turn: Optional[bool] = False

def __repr__(self) -> str:
rep = f'dataset_name: {self.dataset_name}, hf_hub_url: {self.hf_hub_url}, local_path: {self.local_path}, data_formate:{self.dataset_format} load_from_local: {self.load_from_local}, multi_turn: {self.multi_turn}'
rep = (f'dataset_name: {self.dataset_name} || '
f'hf_hub_url: {self.hf_hub_url} || '
f'local_path: {self.local_path} \n'
f'data_formate: {self.dataset_format} || '
f'load_from_local: {self.load_from_local} || '
f'multi_turn: {self.multi_turn}')
return rep

def __post_init__(self):
Expand Down Expand Up @@ -104,6 +109,11 @@ def init_for_training(self): # support mixing multiple datasets
if datasets_info[name]['local_path'] and os.path.exists(
datasets_info[name]['local_path']):
dataset_attr.load_from_local = True
else:
dataset_attr.load_from_local = False
raise Warning(
'You have set local_path for {} but it does not exist! Will load the data from {}'
.format(name, dataset_attr.hf_hub_url))

if 'columns' in datasets_info[name]:
dataset_attr.prompt_column = datasets_info[name][
Expand Down
8 changes: 5 additions & 3 deletions chatllms/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,12 +471,14 @@ def make_data_module(args):

for dataset_attr in args.datasets_list:
print('=' * 80)
print('DatasetAttr: {}...'.format(dataset_attr))
print('DatasetAttr: {}'.format(dataset_attr))

if dataset_attr.load_from_local:
dataset_path = dataset_attr.local_path
elif dataset_attr.hf_hub_url:
dataset_path = dataset_attr.hf_hub_url
else:
raise ValueError('Please set the dataset path or hf_hub_url.')

dataset = load_data(dataset_path,
eval_dataset_size=args.eval_dataset_size)
Expand All @@ -498,11 +500,11 @@ def make_data_module(args):
max_train_samples=args.max_train_samples,
)
if train_dataset:
print('loaded dataset:', dataset_attr.dataset_name,
print('loaded dataset:', dataset_attr.dataset_name, ' ',
'#train data size:', len(train_dataset))
train_datasets.append(train_dataset)
if eval_dataset:
print('loaded dataset:', dataset_attr.dataset_name,
print('loaded dataset:', dataset_attr.dataset_name, ' '
'#eval data size:', len(eval_dataset))
eval_datasets.append(eval_dataset)

Expand Down
2 changes: 0 additions & 2 deletions train_qlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def main():
args = argparse.Namespace(**vars(model_args), **vars(data_args),
**vars(training_args), **vars(lora_args),
**vars(quant_args))

print(args.datasets_list)
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
if not os.path.exists(args.output_dir):
Expand Down