feat: use huggingface hub dataset directly for embedding fine-tune
LongxingTan committed Jul 8, 2024
1 parent 3bc79e7 commit 4d9268f
Showing 25 changed files with 389 additions and 117 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -8,7 +8,7 @@ If you are interested in contributing to open-retrievals,
Once you finish implementing a feature or a bug fix, please send a Pull Request to https://github.com/LongxingTan/open-retrievals


## Developing TFTS
## Developing open-retrievals

To develop open-retrievals on your machine, here are some tips:

7 changes: 4 additions & 3 deletions README.md
@@ -36,7 +36,7 @@
![structure](./docs/source/_static/structure.png)

**Open-retrievals** improves and unifies text embedding, retrieval, reranking, and RAG.
- Embeddings fine-tuned through point-wise, pairwise, listwise, contrastive learning, and LLM.
- Embedding fine-tuned through point-wise, pairwise, listwise, contrastive learning, and LLM.
- Reranking fine-tuned with Cross Encoder, ColBERT, and LLM.
- Easily build enhanced RAG, integrated with Transformers, Langchain, and LlamaIndex.

@@ -104,7 +104,7 @@ model.build_index(sentences, index_path=index_path)

query_embed = model.encode("He plays guitar.")
matcher = AutoModelForRetrieval()
dists, indices = matcher.similarity_search(query_embed, index_path=index_path)
dists, indices = matcher.search(query_embed, index_path=index_path)
print(indices)
```

@@ -230,7 +230,7 @@ trainer.scheduler = scheduler
trainer.train()
```

**Rerank Fine-tuning**
**Reranking Fine-tuning**

```python
from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup, AdamW
@@ -270,6 +270,7 @@ trainer.train()

## Reference & Acknowledgement
- [sentence-transformers](https://github.com/UKPLab/sentence-transformers)
- [Dense](https://github.com/luyug/Dense)
- [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding)
- [uniem](https://github.com/wangyuxinwhy/uniem)
- [BCEmbedding](https://github.com/netease-youdao/BCEmbedding)
5 changes: 3 additions & 2 deletions README_ja-JP.md
@@ -96,7 +96,7 @@ model.build_index(sentences, index_path=index_path)

query_embed = model.encode("He plays guitar.")
matcher = AutoModelForRetrieval()
dists, indices = matcher.similarity_search(query_embed, index_path=index_path)
dists, indices = matcher.search(query_embed, index_path=index_path)
print(indices)
```

@@ -283,12 +283,13 @@ query_embeddings = model.encode(query_texts, convert_to_tensor=True)
document_embeddings = model.encode(document_texts, convert_to_tensor=True)

matcher = AutoModelForRetrieval(method='cosine')
dists, indices = matcher.similarity_search(query_embeddings, document_embeddings, top_k=1)
dists, indices = matcher.search(query_embeddings, document_embeddings, top_k=1)
```


## Reference & Acknowledgement
- [sentence-transformers](https://github.com/UKPLab/sentence-transformers)
- [Dense](https://github.com/luyug/Dense)
- [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding)
- [uniem](https://github.com/wangyuxinwhy/uniem)
- [BCEmbedding](https://github.com/netease-youdao/BCEmbedding)
11 changes: 6 additions & 5 deletions README_zh-CN.md
@@ -39,10 +39,10 @@
- Full reranking fine-tuning supported: cross encoder, ColBERT, and LLM
- Customizable RAG framework, with fine-tuned models easy to use in Transformers, Langchain, and LlamaIndex

| Experiment | Model | Size | Original score | Fine-tuned score | Demo code |
|---------------------|-------------------------|----|-------|-----------|-------------------------------------------------------------------------------------------------------------------------------------|
| **Embedding** pairwise fine-tuning | bge-base-zh-v1.5 | - | 0.657 | **0.703** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing) |
| **Embedding** LLM LoRA fine-tuning | Qwen2-1.5B-Instruct | - | 0.546 | **0.694** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
| Experiment | Model | Size | Original score | Fine-tuned score | Demo code |
|-----------------------|-------------------------|----|-------|-----------|-------------------------------------------------------------------------------------------------------------------------------------|
| Pairwise fine-tuning of **embedding** | bge-base-zh-v1.5 | - | 0.657 | **0.703** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing) |
| LLM LoRA fine-tuning of **embedding** | Qwen2-1.5B-Instruct | - | 0.546 | **0.694** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
| cross encoder **reranking** | bge-reranker-base | - | 0.666 | **0.706** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing) |
| colbert **reranking** | chinese-roberta-wwm-ext | - | 0.643 | **0.687** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing) |
| LLM **reranking** | Qwen2-1.5B-Instruct | - | 0.531 | **0.699** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fzq1iV7-f8hNKFnjMmpVhVxadqPb9IXk?usp=sharing) |
@@ -104,7 +104,7 @@ model.build_index(sentences, index_path=index_path)

query_embed = model.encode("He plays guitar.")
matcher = AutoModelForRetrieval()
dists, indices = matcher.similarity_search(query_embed, index_path=index_path)
dists, indices = matcher.search(query_embed, index_path=index_path)
print(indices)
```

@@ -326,6 +326,7 @@ torchrun --nproc_per_node 1 \

## Reference & Acknowledgement
- [sentence-transformers](https://github.com/UKPLab/sentence-transformers)
- [Dense](https://github.com/luyug/Dense)
- [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding)
- [uniem](https://github.com/wangyuxinwhy/uniem)
- [BCEmbedding](https://github.com/netease-youdao/BCEmbedding)
4 changes: 4 additions & 0 deletions docs/source/embed.rst
@@ -112,6 +112,10 @@ Enhance the performance
* maxsim (multi vector)
* Matryoshka

Tune the important parameters (see the sketch below):

* temperature
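
To make the role of `temperature` concrete, here is a minimal, illustrative sketch of an InfoNCE-style contrastive loss with in-batch negatives. The function name and default value are placeholders rather than the library's API; a lower temperature sharpens the softmax over similarity scores, weighting hard negatives more heavily.

```python
import torch
import torch.nn.functional as F


def info_nce_loss(query_embed: torch.Tensor, doc_embed: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """In-batch contrastive loss: row i of doc_embed is the positive for row i of query_embed."""
    query_embed = F.normalize(query_embed, dim=-1)
    doc_embed = F.normalize(doc_embed, dim=-1)
    scores = query_embed @ doc_embed.T / temperature  # smaller temperature -> sharper distribution
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)
```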


Hard mining
~~~~~~~~~~~~~~~~~~~~~~
2 changes: 1 addition & 1 deletion docs/source/quick-start.rst
@@ -86,7 +86,7 @@ Save the document embedding offline.
query_embed = model.encode("He plays guitar.")
matcher = AutoModelForRetrieval()
dists, indices = matcher.similarity_search(query_embed, index_path=index_path)
dists, indices = matcher.search(query_embed, index_path=index_path)
print(indices)
12 changes: 6 additions & 6 deletions examples/README.md
@@ -30,7 +30,7 @@ torchrun --nproc_per_node 1 \
--overwrite_output_dir \
--model_name_or_path $MODEL_NAME \
--do_train \
--train_data $TRAIN_DATA \
--data_name_or_path $TRAIN_DATA \
--positive_key positive \
--negative_key negative \
--learning_rate 3e-5 \
@@ -66,7 +66,7 @@ torchrun --nproc_per_node 1 \
--model_name_or_path $MODEL_NAME \
--pooling_method last \
--do_train \
--train_data $TRAIN_DATA \
--data_name_or_path $TRAIN_DATA \
--positive_key positive \
--negative_key negative \
--use_lora True \
@@ -103,7 +103,7 @@ python -m retrievals.pipelines.embed \
--do_encode \
--fp16 \
--per_device_eval_batch_size 256 \
--train_data $QUERY \
--data_name_or_path $QUERY \
--is_query true
```

@@ -124,7 +124,7 @@ torchrun --nproc_per_node 1 \
--model_name_or_path $MODEL_NAME \
--model_type cross-encoder \
--do_train \
--train_data $TRAIN_DATA \
--data_name_or_path $TRAIN_DATA \
--positive_key positive \
--negative_key negative \
--learning_rate 2e-5 \
@@ -152,7 +152,7 @@ torchrun --nproc_per_node 1 \
--tokenizer_name $MODEL_NAME \
--model_type colbert \
--do_train \
--train_data $TRAIN_DATA \
--data_name_or_path $TRAIN_DATA \
--positive_key positive \
--negative_key negative \
--learning_rate 1e-4 \
@@ -185,7 +185,7 @@ torchrun --nproc_per_node 1 \
--model_type llm \
--causal_lm True \
--use_lora True \
--train_data $TRAIN_DATA \
--data_name_or_path $TRAIN_DATA \
--task_prompt "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." \
--query_instruction "A: " \
--document_instruction 'B: ' \
18 changes: 10 additions & 8 deletions examples/embedding_llm_finetune.py
@@ -66,7 +66,7 @@ class ModelArguments:

@dataclass
class DataArguments:
train_data: str = field(default="intfloat/personalized_passkey_retrieval", metadata={"help": "Path to train data"})
data_name_or_path: str = field(
default="intfloat/personalized_passkey_retrieval", metadata={"help": "Path to train data"}
)
train_group_size: int = field(default=8)
query_max_length: int = field(
default=32,
@@ -92,8 +94,8 @@ class DataArguments:
document_instruction: str = field(default=None, metadata={"help": "instruction for document"})

def __post_init__(self):
if not os.path.exists(self.train_data):
    raise FileNotFoundError(f"cannot find file: {self.train_data}, please set a valid path")
if not os.path.exists(self.data_name_or_path):
    raise FileNotFoundError(f"cannot find file: {self.data_name_or_path}, please set a valid path")


@dataclass
@@ -131,12 +133,12 @@ def __init__(self, args: DataArguments, tokenizer):
self.args = args
self.tokenizer = tokenizer

if os.path.isdir(args.train_data):
if os.path.isdir(args.data_name_or_path):
train_datasets = []
for file in os.listdir(args.train_data):
for file in os.listdir(args.data_name_or_path):
temp_dataset = datasets.load_dataset(
"json",
data_files=os.path.join(args.train_data, file),
data_files=os.path.join(args.data_name_or_path, file),
split="train",
)
if len(temp_dataset) > args.max_example_num_per_dataset:
@@ -149,8 +151,8 @@ def __init__(self, args: DataArguments, tokenizer):
train_datasets.append(temp_dataset)
self.dataset = datasets.concatenate_datasets(train_datasets)
else:
# self.dataset = datasets.load_dataset("json", data_files=args.train_data, split="train")
self.dataset = datasets.load_dataset(args.train_data)
# self.dataset = datasets.load_dataset("json", data_files=args.data_name_or_path, split="train")
self.dataset = datasets.load_dataset(args.data_name_or_path)

def __len__(self):
return len(self.dataset)
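For context, a standalone sketch of the loading behavior this change enables — a local directory of JSON files versus a dataset id resolved directly on the Hugging Face Hub. The helper name and the explicit `split` argument are illustrative, not part of the library:

```python
import os

import datasets


def load_train_dataset(data_name_or_path: str, split: str = "train") -> datasets.Dataset:
    """Load training data from a local JSON directory or directly from the Hugging Face Hub."""
    if os.path.isdir(data_name_or_path):
        # Local directory: gather every JSON/JSONL file it contains.
        data_files = [os.path.join(data_name_or_path, f) for f in os.listdir(data_name_or_path)]
        return datasets.load_dataset("json", data_files=data_files, split=split)
    # Anything else is treated as a Hub dataset id, e.g. "intfloat/personalized_passkey_retrieval".
    return datasets.load_dataset(data_name_or_path, split=split)
```
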
12 changes: 9 additions & 3 deletions examples/eval/README.md
@@ -3,8 +3,7 @@
**Prerequisites**
```shell
pip install datasets mteb[beir]
pip install C_MTEB
pip install open-retrievals
pip install open-retrievals[eval]
```


@@ -16,7 +15,7 @@ from retrievals import AutoModelForEmbedding

class AutoModelForEmbeddingEval(AutoModelForEmbedding):
def __init__(self, **kwargs):
super(AutoModelForEmbeddingEval, self).__init__()
super(AutoModelForEmbeddingEval, self).__init__(**kwargs)

def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
"""For MTEB eval
@@ -40,3 +39,10 @@ class AutoModelForEmbeddingEval(AutoModelForEmbedding):
input_texts = corpus
return self.encode_from_text(input_texts, batch_size=4)
```
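
Once such a wrapper is defined, it can be plugged into an MTEB run. A hedged sketch follows — the `from_pretrained` construction and the task name are assumptions for illustration, not taken from this commit:

```python
from mteb import MTEB

# Assumption: the eval wrapper is constructed the same way as AutoModelForEmbedding itself.
model = AutoModelForEmbeddingEval.from_pretrained("BAAI/bge-base-en-v1.5")

evaluation = MTEB(tasks=["SciFact"])            # task name is illustrative
evaluation.run(model, output_folder="outputs")  # writes per-task JSON results
```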


## Reference

- https://github.com/beir-cellar/beir
- https://github.com/AmenRa/ranx
- [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard)
145 changes: 145 additions & 0 deletions examples/rerank_llm_finetune.py
@@ -0,0 +1,145 @@
from dataclasses import dataclass, field
from typing import Optional

from transformers import (
AdamW,
AutoTokenizer,
HfArgumentParser,
TrainingArguments,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)

from retrievals import (
AutoModelForRanking,
LLMRerankCollator,
RerankDataset,
RerankTrainer,
RetrievalDataset,
)
from retrievals.losses import TokenLoss


@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""

model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
causal_lm: bool = field(default=False, metadata={'help': "Whether the model is a causal lm or not"})


@dataclass
class DataArguments:
data_name_or_path: str = field(default=None, metadata={"help": "Path to corpus"})
train_group_size: int = field(default=8)
unfold_each_positive: bool = field(default=False)
max_length: int = field(
default=512,
metadata={
"help": "The maximum total input sequence length after tokenization for input text. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
query_key: str = field(default=None)
positive_key: str = field(default=None)
negative_key: str = field(default=None)

query_instruction: str = field(default=None, metadata={"help": "instruction for query"})
document_instruction: str = field(default=None, metadata={"help": "instruction for document"})
task_prompt: str = field(
default=(
"Given a query A and a passage B, determine whether the passage contains an answer "
"to the query by providing a prediction of either 'Yes' or 'No'."
)
)


@dataclass
class RerankerTrainingArguments(TrainingArguments):
model_type: str = field(default='cross-encoder', metadata={'help': "train type of cross-encoder, colbert"})
negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
use_inbatch_negative: bool = field(default=False)
temperature: Optional[float] = field(default=0.02)
remove_unused_columns: bool = field(default=False)
num_train_epochs: int = field(default=3)
use_lora: bool = field(default=False)
use_bnb_config: bool = field(default=False)
do_rerank: bool = field(default=False, metadata={"help": "run the reranking loop"})


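# Two parameter groups: backbone parameters (under `model.model`) train at the base learning rate,
# while the remaining parameters (e.g. a scoring head) get a 20x larger rate and no weight decay.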
def get_optimizer(model, learning_rate, weight_decay=0.0):
optimizer_parameters = [
{
"params": [p for n, p in model.model.named_parameters()],
"lr": learning_rate,
"weight_decay": weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if "model" not in n],
"lr": learning_rate * 20,
"weight_decay": 0.0,
},
]
return AdamW(optimizer_parameters)


parser = HfArgumentParser((ModelArguments, DataArguments, RerankerTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()


tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=False,
)

train_dataset = RetrievalDataset(
args=data_args,
tokenizer=tokenizer,
unfold_each_positive=data_args.unfold_each_positive,
train_group_size=data_args.train_group_size,
positive_key=data_args.positive_key,
negative_key=data_args.negative_key,
)
data_collator = LLMRerankCollator(tokenizer=tokenizer, max_length=data_args.max_length, prompt=data_args.task_prompt)
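# The causal-LM reranker is supervised through a single token: TokenLoss is built around the id of
# the 'Yes' token (looked up below), so the relevance score comes from predicting 'Yes' vs 'No'
# for the task prompt defined in DataArguments.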
token_index = tokenizer('Yes', add_special_tokens=False)['input_ids'][-1]
model = AutoModelForRanking.from_pretrained(
model_args.model_name_or_path,
num_labels=1,
loss_fn=TokenLoss(token_index=token_index, train_group_size=data_args.train_group_size),
causal_lm=True,
use_lora=training_args.use_lora,
quantization_config=None,
)
optimizer = get_optimizer(model, learning_rate=training_args.learning_rate)

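# Linear decay schedule with 5% of the estimated total training steps used for warmup.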
num_train_steps = int(len(train_dataset) / training_args.per_device_train_batch_size * training_args.num_train_epochs)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
)

trainer = RerankTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
tokenizer=tokenizer,
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler

trainer.train()
model.save_pretrained(training_args.output_dir)
