add special tokens support in sentencepiece and tiktoken tokenizer.
guocuimi committed Dec 28, 2023
1 parent cdc11ab commit da5ea71
Showing 7 changed files with 402 additions and 43 deletions.
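
For orientation before the diffs, here is a minimal usage sketch of the new special-token support, modeled on the test added in this commit. The model path and token list come from that test; the main() wrapper, the include line, and the printed output are illustrative assumptions, not part of the diff.

#include "sentencepiece_tokenizer.h"

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Model path and special tokens taken from the test in this commit.
  const std::vector<std::string> special_tokens = {
      "<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"};
  llm::SentencePieceTokenizer tokenizer(
      "data/tokenizer.model", special_tokens, /*prepend_bos=*/false);

  // Special tokens are assigned ids right after the SentencePiece vocabulary,
  // so vocab_size() grows by the number of special tokens.
  std::cout << "vocab size: " << tokenizer.vocab_size() << "\n";

  std::vector<int32_t> ids;
  if (tokenizer.encode("<|user|> Hello <|assistant|>", &ids)) {
    // Each special token encodes to a single id; the surrounding text goes
    // through the underlying SentencePiece model, and decode() restores both.
    std::cout << tokenizer.decode(ids) << "\n";
  }
  return 0;
}
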
1 change: 1 addition & 0 deletions src/tokenizer/CMakeLists.txt
@@ -15,6 +15,7 @@ cc_library(
hf_tokenizer.cpp
DEPS
:sentencepiece
absl::flat_hash_map
tokenizers
glog::glog
re2::re2
171 changes: 155 additions & 16 deletions src/tokenizer/sentencepiece_tokenizer.cpp
@@ -1,49 +1,188 @@
#include "sentencepiece_tokenizer.h"

#include <absl/strings/str_cat.h>
#include <absl/strings/str_join.h>
#include <re2/re2.h>

#include "common/logging.h"
#include "sentencepiece.pb.h"
#include "sentencepiece/sentencepiece_processor.h"

#define RETURN_FALSE_IF_ERROR(expr) \
do { \
const auto _status = expr; \
if (!_status.ok()) return false; \
} while (0)

#define RETURN_IF_ERROR(expr) \
do { \
const auto _status = expr; \
if (!_status.ok()) return; \
} while (0)

namespace llm {

SentencePieceTokenizer::SentencePieceTokenizer(
const std::string& vocab_file_path,
const std::vector<std::string>& special_tokens,
bool prepend_bos)
: vocab_file_path_(vocab_file_path), prepend_bos_(prepend_bos) {
: vocab_file_path_(vocab_file_path),
special_tokens_(special_tokens),
prepend_bos_(prepend_bos) {
const auto status = sp_processor_.Load(vocab_file_path);
if (!status.ok()) {
GLOG(FATAL) << "Failed to load SentencePiece model from " << vocab_file_path
<< ": " << status.ToString();
}

if (special_tokens.empty()) {
// no special tokens, just return
return;
}

// add special tokens and construct special token regex
// TODO: use special token start id from tokenizer args
int32_t next_id = static_cast<int32_t>(sp_processor_.GetPieceSize());
for (const auto& token : special_tokens) {
if (token.empty()) {
continue;
}
if (!special_token_encoder_.try_emplace(token, next_id).second) {
GLOG(WARNING) << "Duplicate special token: " << token;
}
if (!special_token_decoder_.try_emplace(next_id, token).second) {
GLOG(WARNING) << "Duplicate special token id: " << next_id;
}
++next_id;
}

// build special token regex
std::vector<std::string> escaped_tokens;
escaped_tokens.reserve(special_tokens.size());
for (const auto& token : special_tokens) {
if (token.empty()) {
continue;
}
// escape each token
const auto escaped_token = re2::RE2::QuoteMeta(token);
escaped_tokens.push_back(escaped_token);
}
if (!escaped_tokens.empty()) {
const auto special_token_regex_str = absl::StrJoin(escaped_tokens, "|");
// surround with () to match special tokens
const auto regex_str = absl::StrCat("(", special_token_regex_str, ")");
special_token_regex_ = std::make_unique<re2::RE2>(regex_str);
}
}

std::unique_ptr<Tokenizer> SentencePieceTokenizer::clone() const {
return std::make_unique<SentencePieceTokenizer>(this->vocab_file_path_,
this->prepend_bos_);
bool SentencePieceTokenizer::encode_internal(const std::string_view& text,
std::vector<int32_t>* ids) const {
if (text.empty()) {
// empty text, just return
return true;
}

sentencepiece::SentencePieceText spt;
RETURN_FALSE_IF_ERROR(sp_processor_.Encode(text, &spt));
for (const auto& sp : spt.pieces()) {
ids->emplace_back(sp.id());
}
return true;
}

bool SentencePieceTokenizer::encode(const std::string_view& text,
std::vector<int32_t>* ids) const {
const auto status = sp_processor_.Encode({text.data(), text.size()}, ids);
if (!status.ok()) {
GLOG(ERROR) << "Failed to encode text: " << text << ", error "
<< status.ToString();
return false;
}
// prepend bos token
if (prepend_bos_) {
ids->insert(ids->begin(), sp_processor_.bos_id());
}
return true;

if (special_token_regex_ == nullptr) {
return encode_internal(text, ids);
}

std::string_view input = text;
std::string_view special;
while (true) {
const auto* start = input.begin();
if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) {
// no more special tokens
break;
}

// encode text before special token if exists
const std::string_view sub_input(start,
input.begin() - start - special.size());
if (!encode_internal(sub_input, ids)) {
return false;
}

// add special token id if exists
const auto sit = special_token_encoder_.find(special);
if (sit != special_token_encoder_.end()) {
// find one special token
ids->push_back(sit->second);
}
}

// encode remaining text if exists
return encode_internal(input, ids);
}

void SentencePieceTokenizer::decode_internal(const std::vector<int32_t>& ids,
size_t start,
size_t end,
std::stringstream* ss) const {
if (start >= end) {
// no text to decode
return;
}

sentencepiece::SentencePieceText spt;
std::vector<std::string> pieces;
const int num_pieces = sp_processor_.GetPieceSize();
pieces.reserve(end - start);
for (size_t i = start; i < end; ++i) {
const auto id = ids[i];
if (id < 0 || id >= num_pieces) {
GLOG(ERROR) << "Invalid id: " << id;
continue;
}
pieces.emplace_back(sp_processor_.IdToPiece(id));
}
RETURN_IF_ERROR(sp_processor_.Decode(pieces, &spt));
(*ss) << spt.text();
}

std::string SentencePieceTokenizer::decode(
const std::vector<int32_t>& ids) const {
std::string text;
const auto status = sp_processor_.Decode(ids, &text);
if (!status.ok()) {
GLOG(ERROR) << "Failed to decode ids: " << status.ToString();
std::stringstream ss;
size_t start = 0;
for (size_t i = 0; i < ids.size(); ++i) {
const auto sit = special_token_decoder_.find(ids[i]);
if (sit == special_token_decoder_.end()) {
continue;
}
// decode text before special token if exists
decode_internal(ids, start, i, &ss);
// output special token
ss << sit->second;
start = i + 1;
}
return text;

// decode remaining text if exists
decode_internal(ids, start, ids.size(), &ss);
return ss.str();
}

size_t SentencePieceTokenizer::vocab_size() const {
// vocab size = sentencepiece vocab size + special tokens
return sp_processor_.GetPieceSize() + special_tokens_.size();
}

std::unique_ptr<Tokenizer> SentencePieceTokenizer::clone() const {
return std::make_unique<SentencePieceTokenizer>(
this->vocab_file_path_, this->special_tokens_, this->prepend_bos_);
}

} // namespace llm
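
As a side note on the encode() loop above: re2::RE2::FindAndConsume advances the input view past each match, so the plain-text segment preceding a special token spans from the previous cursor up to the start of the consumed match, i.e. input.begin() - start - special.size() characters. Below is a self-contained sketch of just that splitting step, using a hand-escaped two-token regex; it mirrors the types used in the diff and, like the committed code, assumes an RE2 build where absl::string_view interoperates with std::string_view.

#include <re2/re2.h>

#include <iostream>
#include <string>
#include <string_view>

int main() {
  // Hypothetical regex covering two special tokens, escaped by hand.
  const re2::RE2 special_token_regex("(<\\|user\\|>|<\\|assistant\\|>)");
  std::string_view input = "<|user|> Hello <|assistant|> Hi";

  std::string_view special;
  while (true) {
    const char* start = input.data();
    if (!re2::RE2::FindAndConsume(&input, special_token_regex, &special)) {
      break;  // no more special tokens
    }
    // FindAndConsume moved `input` past the match, so the plain text before
    // the special token spans [start, input.data() - special.size()).
    const std::string_view plain(start, input.data() - start - special.size());
    std::cout << "plain: '" << plain << "' special: '" << special << "'\n";
  }
  std::cout << "tail: '" << input << "'\n";  // text after the last match
  return 0;
}
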
30 changes: 28 additions & 2 deletions src/tokenizer/sentencepiece_tokenizer.h
@@ -1,4 +1,8 @@
#pragma once
#include <absl/container/flat_hash_map.h>
#include <re2/re2.h>

#include <cstdint>

#include "sentencepiece/sentencepiece_processor.h"
#include "tokenizer.h"
@@ -8,22 +12,44 @@ namespace llm {
// a tokenizer that uses google/SentencePiece
class SentencePieceTokenizer : public Tokenizer {
public:
explicit SentencePieceTokenizer(const std::string& vocab_file_path, bool prepend_bos);
SentencePieceTokenizer(const std::string& vocab_file_path,
const std::vector<std::string>& special_tokens,
bool prepend_bos);

SentencePieceTokenizer(const std::string& vocab_file_path, bool prepend_bos)
: SentencePieceTokenizer(vocab_file_path, {}, prepend_bos) {}

bool encode(const std::string_view& text,
std::vector<int32_t>* ids) const override;

std::string decode(const std::vector<int32_t>& ids) const override;

size_t vocab_size() const override { return sp_processor_.GetPieceSize(); }
size_t vocab_size() const override;

std::unique_ptr<Tokenizer> clone() const override;

private:
bool encode_internal(const std::string_view& text,
std::vector<int32_t>* ids) const;
void decode_internal(const std::vector<int32_t>& ids,
size_t start,
size_t end,
std::stringstream* ss) const;
std::string vocab_file_path_;

std::vector<std::string> special_tokens_;

sentencepiece::SentencePieceProcessor sp_processor_;

// special tokens to ids
absl::flat_hash_map<std::string, int32_t> special_token_encoder_;

// special token ids to tokens
absl::flat_hash_map<int32_t, std::string> special_token_decoder_;

// special token regex (optional)
std::unique_ptr<re2::RE2> special_token_regex_;

bool prepend_bos_ = false;
};

43 changes: 41 additions & 2 deletions src/tokenizer/sentencepiece_tokenizer_test.cpp
@@ -6,17 +6,20 @@ namespace llm {

TEST(SentencePieceTokenizerTest, EncodeDecodeTest) {
SentencePieceTokenizer tokenizer("data/tokenizer.model", /*prepend_bos=*/true);
EXPECT_EQ(tokenizer.vocab_size(), 32000);
const std::string test_text = "Hello, world!";
std::vector<int> ids;
ASSERT_TRUE(tokenizer.encode("Hello, world!", &ids));
ASSERT_TRUE(tokenizer.encode(test_text, &ids));
const std::vector<int> desired_ids = {1, 15043, 29892, 3186, 29991};
EXPECT_EQ(ids, desired_ids);

const auto text = tokenizer.decode(ids);
EXPECT_EQ(text, "Hello, world!");
EXPECT_EQ(text, test_text);
}

TEST(SentencePieceTokenizerTest, CJKTest) {
SentencePieceTokenizer tokenizer("data/tokenizer.model", /*prepend_bos=*/true);
EXPECT_EQ(tokenizer.vocab_size(), 32000);
const std::string test_text = "你好,世界!";
std::vector<int> ids;
ASSERT_TRUE(tokenizer.encode(test_text, &ids));
@@ -27,4 +30,40 @@ TEST(SentencePieceTokenizerTest, CJKTest) {
const auto decoded_text = tokenizer.decode(ids);
EXPECT_EQ(decoded_text, test_text);
}

TEST(SentencePieceTokenizerTest, SpecialTokenTest) {
// clang-format off
std::vector<std::string> special_tokens = {
"[gMASK]", "[sMASK]", "sop", "eop",
"<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"
};
// clang-format on
SentencePieceTokenizer tokenizer("data/tokenizer.model", special_tokens, /*prepend_bos=*/false);
EXPECT_EQ(tokenizer.vocab_size(), 32000 + special_tokens.size());
// test encode each special token
for (const auto& token : special_tokens) {
std::vector<int> ids;
ASSERT_TRUE(tokenizer.encode(token, &ids));
EXPECT_EQ(ids.size(), 1);
const auto decoded_token = tokenizer.decode(ids);
EXPECT_EQ(decoded_token, token);
}

// test encode text with special tokens
const std::string test_text =
"<|system|> Hello world <|user|> Hello <|assistant|>";
std::vector<int> ids;
ASSERT_TRUE(tokenizer.encode(test_text, &ids));
// clang-format off
const std::vector<int> desired_ids = {
32004, 29871, 15043, 3186, 29871,
32005, 29871, 15043, 29871,
32006
};
// clang-format on
EXPECT_EQ(ids, desired_ids);

const auto text = tokenizer.decode(ids);
EXPECT_EQ(text, test_text);
}
} // namespace llm