added chat template for chatglm
guocuimi committed Dec 26, 2023
1 parent bc1aba7 commit 926a06c
Showing 2 changed files with 42 additions and 7 deletions.
4 changes: 3 additions & 1 deletion src/models/args.h
@@ -35,7 +35,7 @@ struct ModelArgs {

  // whether to use rms norm.
  DEFINE_ARG(bool, use_rms_norm) = false;

  // the epsilon value to use for rms norm.
  DEFINE_ARG(float, rms_norm_eps) = 0.0f;

@@ -127,6 +127,8 @@ inline std::ostream& operator<<(std::ostream& os, const ModelArgs& args) {
os << ", attn_alibi: " << args.attn_alibi();
os << ", alibi_bias_max: " << args.alibi_bias_max();
os << ", no_bias: " << args.no_bias();
os << ", linear_bias: " << args.linear_bias();
os << ", qkv_bias: " << args.qkv_bias();
os << ", residual_post_layernorm: " << args.residual_post_layernorm();
os << "]";
return os;
45 changes: 39 additions & 6 deletions src/models/huggingface/chatglm.h
@@ -121,16 +121,17 @@ class ChatGLMAttentionImpl : public torch::nn::Module {
                       dtype,
                       device));

    // initialize positional embedding
    // initialize attention
    // initialize positional embedding and attention
    const int64_t rotary_dim =
        static_cast<int64_t>(head_dim * args.rotary_pct());
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    atten_ = register_module(
        "atten",
        AttentionWithRoPE(n_local_heads,
                          n_local_kv_heads,
                          head_dim,
                          scale,
                          /*rotary_dim=*/head_dim / 2,
                          rotary_dim,
                          args.rope_scaling(),
                          args.rope_theta(),
                          /*max_position=*/args.max_position_embeddings(),
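Note on the change above: the hardcoded /*rotary_dim=*/head_dim / 2 is replaced by a value derived from the config. For chatglm3-6b, head_dim = hidden_size / num_attention_heads = 4096 / 32 = 128, and rotary_pct defaults to 0.5f (see the model args below), so rotary_dim = 64 — numerically identical to head_dim / 2, but now adjustable per model.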
@@ -423,26 +424,52 @@ class ChatGLMForCausalLMImpl : public torch::nn::Module {
};
TORCH_MODULE(ChatGLMForCausalLM);

class ChatGLMConversation final : public Conversation {
 public:
  // generate prompt from dialogs
  // https://github.com/THUDM/ChatGLM3/blob/main/PROMPT.md
  std::optional<std::string> get_prompt() const override {
    // require an odd number of messages so the conversation ends with a user turn
    if (messages_.size() % 2 == 0) {
      return std::nullopt;
    }

    std::stringstream ss;
    if (!system_message_.empty()) {
      ss << "<|system|>\n" << system_message_ << "\n";
    }

    // then user and assistant message pairs (u/a/u/a/u...)
    for (size_t i = 0; i < messages_.size(); ++i) {
      const char* role = (i % 2) == 0 ? "user" : "assistant";
      ss << "<|" << role << "|>\n" << messages_[i] << "\n";
    }
    // end with assistant message
    ss << "<|assistant|>\n";
    return ss.str();
  }
};
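For illustration, here is a minimal standalone sketch of the prompt this template produces (the Conversation base class, messages_, and system_message_ are internal to the repo; this hypothetical helper simply mirrors the same logic):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// standalone mirror of ChatGLMConversation::get_prompt() for illustration
std::string build_chatglm_prompt(const std::string& system_message,
                                 const std::vector<std::string>& messages) {
  std::stringstream ss;
  if (!system_message.empty()) {
    ss << "<|system|>\n" << system_message << "\n";
  }
  // alternating user/assistant turns (u/a/u/a/u...)
  for (size_t i = 0; i < messages.size(); ++i) {
    const char* role = (i % 2) == 0 ? "user" : "assistant";
    ss << "<|" << role << "|>\n" << messages[i] << "\n";
  }
  // generation starts after the trailing assistant tag
  ss << "<|assistant|>\n";
  return ss.str();
}

int main() {
  std::cout << build_chatglm_prompt("You are a helpful assistant.",
                                    {"What is 2 + 2?"});
  // prints:
  // <|system|>
  // You are a helpful assistant.
  // <|user|>
  // What is 2 + 2?
  // <|assistant|>
}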

// register the model to make it available
REGISTER_CAUSAL_MODEL(chatglm, ChatGLMForCausalLM);
REGISTER_CONVERSATION_TEMPLATE(chatglm, ChatGLMConversation);
REGISTER_MODEL_ARGS(chatglm, [&] {
  // example config:
  // https://huggingface.co/THUDM/chatglm3-6b/blob/main/config.json
  LOAD_ARG_OR(model_type, "model_type", "chatglm");
  LOAD_ARG_OR(dtype, "torch_dtype", "");
  LOAD_ARG_OR(vocab_size, "padded_vocab_size", 65024);
  LOAD_ARG_OR(hidden_size, "hidden_size", 4096);
  LOAD_ARG_OR(intermediate_size, "ffn_hidden_size", 13696);
  LOAD_ARG_OR(n_layers, "num_layers", 28);
  LOAD_ARG_OR(n_heads, "num_attention_heads", 32);
  LOAD_ARG_OR(use_rms_norm, "rmsnorm", false);
  LOAD_ARG_OR(layer_norm_eps, "layernorm_epsilon", 1e-5);
  LOAD_ARG_OR(bos_token_id, "bos_token_id", 1);
  LOAD_ARG_OR(eos_token_id, "eos_token_id", 2);
  LOAD_ARG_OR(residual_post_layernorm,
              "apply_residual_connection_post_layernorm",
              false);
  LOAD_ARG_OR(max_position_embeddings, "seq_length", 8192);

  // assign kv heads from multi_query_group_num if multi_query_attention is used
  LOAD_ARG_OR_FUNC(n_kv_heads, "num_kv_attention_heads", [&] {
@@ -454,10 +481,16 @@ REGISTER_MODEL_ARGS(chatglm, [&] {
    }
    return n_kv_heads;
  });

  // rotary position embedding related args
  LOAD_ARG_OR(rotary_pct, "rotary_pct", 0.5f);
  LOAD_ARG_OR_FUNC(rope_theta, "rope_theta", [&] {
    const float rope_ratio = json.value_or<float>("rope_ratio", 1.0f);
    return rope_ratio * 10000.0f;
  });
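With the default rope_ratio of 1.0f this evaluates to the standard RoPE base of 10000.0; a config that sets rope_ratio (as some long-context ChatGLM variants do) scales the base proportionally, e.g. rope_ratio = 2.0 would give rope_theta = 20000.0.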

  // stop token ids: "</s>", "<|user|>", "<|observation|>"
  LOAD_ARG_OR(stop_token_ids, "", std::unordered_set<int32_t>({2, 64795, 64797}));
});
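As a rough sketch of the pattern behind these macros (the real LOAD_ARG_OR and LOAD_ARG_OR_FUNC definitions live elsewhere in this repo and may differ in detail), the loading logic amounts to "use the value from config.json if the key is present, otherwise fall back to the hardcoded default":

#include <iostream>
#include <string>
#include <unordered_map>

// hypothetical stand-in for the parsed config.json
using Config = std::unordered_map<std::string, double>;

// use the config value when the key is present, otherwise the default
double load_arg_or(const Config& json, const std::string& key,
                   double default_value) {
  auto it = json.find(key);
  return it != json.end() ? it->second : default_value;
}

int main() {
  const Config config = {{"hidden_size", 4096}, {"num_layers", 28}};
  std::cout << load_arg_or(config, "hidden_size", 0.0) << "\n";  // 4096
  std::cout << load_arg_or(config, "rotary_pct", 0.5) << "\n";   // 0.5 (default)
}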

} // namespace llm::hf
