[Inference] Optimized some scattered optimization points in the framework #5544

Open
wants to merge 1 commit into base: feature/colossal-infer
8 changes: 4 additions & 4 deletions colossalai/inference/core/engine.py
@@ -125,7 +125,7 @@ def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
0, max_num_blocks
) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
output_tensor = torch.empty(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
)
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
@@ -371,13 +371,13 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,

sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
output_tensor = torch.empty(
(input_ids.size(0), batch.num_heads * batch.head_dim),
dtype=batch.dtype,
device=batch.device,
)
else:
output_tensor = torch.zeros(
output_tensor = torch.empty(
(batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)

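Both engine.py changes trade redundant work for cheaper equivalents: torch.empty skips the zero-fill kernel that torch.zeros launches for a buffer the attention kernel overwrites anyway, and input_ids.size(0) reads a host-side shape instead of forcing the GPU-to-CPU sync hidden in sequence_lengths.sum().item(). A minimal sketch of both effects, with illustrative shapes rather than the engine's real configuration:

```python
import torch

# Minimal sketch, not engine code: shapes are illustrative.
num_tokens, hidden = 4096, 4096

if torch.cuda.is_available():
    device = torch.device("cuda")

    # torch.zeros launches an extra fill kernel; torch.empty only reserves memory.
    # The buffer is fully overwritten downstream, so zero-filling is wasted work.
    buf = torch.empty((num_tokens, hidden), device=device)         # no fill kernel
    buf_zeroed = torch.zeros((num_tokens, hidden), device=device)  # reserves + fills with 0

    # sequence_lengths.sum().item() copies a scalar GPU -> CPU, which synchronizes the
    # stream and stalls the CPU. During prefill the same number is already known on the
    # host as input_ids.size(0), which is pure metadata and never touches the device.
    sequence_lengths = torch.randint(1, 128, (32,), device=device)
    total_tokens_sync = int(sequence_lengths.sum().item())          # device sync
    input_ids = torch.zeros(total_tokens_sync, dtype=torch.long, device=device)
    total_tokens_free = input_ids.size(0)                           # no sync
    assert total_tokens_sync == total_tokens_free
```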
12 changes: 7 additions & 5 deletions colossalai/inference/core/request_handler.py
@@ -297,11 +297,13 @@ def search_tokens(self, generation_config: GenerationConfig, logits):
Sample tokens for finished requests.
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
for type in ["top_k", "top_p", "min_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
top_p = generation_config.top_p
top_k = generation_config.top_k

if top_k:
logits = logit_processor("top_k", logits, top_k)
if top_p:
logits = logit_processor("top_p", logits, top_p)
yuanheng-zhao marked this conversation as resolved.

# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
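The rewritten search_tokens reads top_k and top_p directly from the generation config and applies the corresponding logit processors before the softmax and sampling step. For readers who want the mechanics without ColossalAI's logit_processor, here is a self-contained sketch of the standard top-k and top-p (nucleus) filters; the function names are illustrative, not the project's API:

```python
import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep only the k largest logits per row; everything else becomes -inf."""
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))

def top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus filtering: keep the smallest prefix of sorted tokens whose mass exceeds p."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    # Drop a token once the mass accumulated *before* it already exceeds p,
    # which guarantees the most likely token always survives.
    remove = (probs.cumsum(dim=-1) - probs) > p
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # Undo the sort so the logits line up with the original vocabulary order.
    return torch.full_like(logits, float("-inf")).scatter_(-1, sorted_idx, sorted_logits)

# Usage mirroring search_tokens: filter, softmax, then sample.
logits = torch.randn(4, 32000)
logits = top_k_filter(logits, k=50)
logits = top_p_filter(logits, p=0.9)
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
next_tokens = torch.multinomial(probs, num_samples=1)
```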
40 changes: 26 additions & 14 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -138,7 +138,7 @@ def get_block_table_kv_ptrs(self, block_table: torch.Tensor, layer_id: int) -> T
"""Get the key and value pointers of physical caches (of specific layer) corresponding to logical cache blocks indicated by the block table."""
k_ptrs = []
v_ptrs = []
for block_id in block_table:
for block_id in block_table.tolist():
if block_id >= 0:
block: CacheBlock = self._cache_blocks[block_id]
k_ptrs.append(block.k_ptrs[layer_id])
@@ -223,46 +223,58 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context
self._block_states_cum[:-num_blocks_required],
out=self._block_finder[num_blocks_required - 1 :],
)

end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1)

block_tables_list = block_tables.tolist()
blocks_required_list = blocks_required.tolist()

if end_indexes.numel() > 0:
# contiguous cache exists
end_idx = end_indexes[0].item() + 1 # open interval
start_idx = end_idx - num_blocks_required # closed interval
alloc_block_ids = torch.arange(start_idx, end_idx)
alloc_block_ids = torch.arange(start_idx, end_idx, device=block_tables.device)
Contributor:
Assigning alloc_block_ids.device to that of block_tables might trigger an error in L259:

self._block_states[alloc_block_ids] = 0

Notice that self._block_states lives on the host. If the passed-in block tables tensor were on a device, you would get the runtime error "Expected all tensors to be on the same device, but found ...". At the moment adding device=block_tables.device here makes no difference, since the batch bucket class keeps its block tables tensor on the host, so it causes no error and changes no behavior. (A minimal illustration of this host/device indexing pitfall follows this file's diff below.)

Contributor Author:
OK, I will fix it.

for i in range(bsz):
curr_required = blocks_required[i]
block_tables[i, :curr_required] = torch.arange(
start_idx, start_idx + curr_required, device=block_tables.device
)
curr_required = blocks_required_list[i]
for j in range(curr_required):
block_tables_list[i][j] = start_idx + j
start_idx += curr_required
else:
# non-contiguous cache
available_block_ids = torch.nonzero(self._block_states > 0).view(-1)
alloc_block_ids = available_block_ids[:num_blocks_required]
alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device)
alloc_block_ids_list = alloc_block_ids.tolist()
yuanheng-zhao marked this conversation as resolved.
start_idx = 0
for i in range(bsz):
curr_required = blocks_required[i]
block_tables[i, :curr_required] = alloc_block_ids[start_idx, start_idx + curr_required]
curr_required = blocks_required_list[i]
for j in range(curr_required):
block_tables_list[i][j] = alloc_block_ids_list[start_idx + j]
start_idx += curr_required

block_tables.copy_(torch.tensor(block_tables_list))

# Update cache blocks
self._block_states[alloc_block_ids] = 0
self._available_blocks -= num_blocks_required
last_block_locs = torch.cumsum(blocks_required, dim=0) - 1
last_block_locs = last_block_locs.to(device=alloc_block_ids.device)

for i, block_id in enumerate(alloc_block_ids[last_block_locs]):
last_block_ids = alloc_block_ids[last_block_locs].tolist()
alloc_block_ids = alloc_block_ids.tolist()
context_lengths = context_lengths.tolist()

for i, block_id in enumerate(last_block_ids):
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
self._allocate_on_block(
block,
block.block_size
if context_lengths[i] % block.block_size == 0
else context_lengths[i].item() % block.block_size,
else context_lengths[i] % block.block_size,
)
for block_id in alloc_block_ids:
if block_id in alloc_block_ids[last_block_locs]:
if block_id in last_block_ids:
continue
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
Expand Down Expand Up @@ -336,15 +348,15 @@ def allocate_tokens_from_block_tables(
dtype=block_tables.dtype, device=block_tables.device
)

for block_id in alloc_block_ids:
for block_id in alloc_block_ids.tolist():
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
self._block_states[block_id] = 0
self._available_blocks -= 1
block_tables[seqs_req_new_blocks, alloc_local_block_indexes[seqs_req_new_blocks]] = alloc_block_ids
block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]

for block_id in block_global_ids:
for block_id in block_global_ids.tolist():
self._allocate_on_block(self._cache_blocks[block_id], 1)

return seqs_to_recycle
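Two things worth illustrating from this file's changes and their review: converting small index tensors to Python lists before tight loops avoids per-element tensor dispatch, and, as the reviewer warns, indexing the host-resident self._block_states with a CUDA index tensor would raise a device-mismatch error. A minimal sketch of both points, with illustrative tensors rather than the manager's real fields (the second part assumes a CUDA device is available):

```python
import torch

# (1) Iterating a tensor yields 0-d tensors; .tolist() yields plain Python ints,
# which is noticeably cheaper inside tight Python loops such as the block-table walks above.
block_table = torch.tensor([3, 7, -1, -1])
ids_as_tensors = [b for b in block_table]        # each b is a 0-d torch.Tensor
ids_as_ints = [b for b in block_table.tolist()]  # each b is a plain int
assert ids_as_ints[:2] == [3, 7]

# (2) The host/device mismatch the reviewer points out: a CPU-resident state table
# cannot be indexed with a CUDA index tensor.
if torch.cuda.is_available():
    block_states = torch.ones(16, dtype=torch.uint8)  # host-side state table
    cuda_ids = torch.arange(0, 4, device="cuda")      # device-side indices
    try:
        block_states[cuda_ids] = 0
    except RuntimeError as e:
        print("device mismatch:", e)                  # message complains about mismatched devices
```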
13 changes: 9 additions & 4 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -121,6 +121,9 @@ def llama_model_forward(
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

norm_output = torch.empty_like(hidden_states)
silu_and_mul_output = torch.empty(
hidden_states.size(0), self.config.intermediate_size, dtype=hidden_states.dtype, device=hidden_states.device
)
residual = None

for layer_id, decoder_layer in enumerate(self.layers):
@@ -137,6 +140,7 @@ kv_seq_len=kv_seq_len,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
norm_output=norm_output,
silu_and_mul_output=silu_and_mul_output,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
@@ -167,6 +171,7 @@ def llama_decoder_layer_forward(
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
silu_and_mul_output: torch.Tensor = None,
Contributor:
Not sure it is a good idea to pass the silu_and_mul output tensor as an argument and thread it module by module down to the MLP layer.

@yuehuayingxueluo (Contributor Author), Apr 9, 2024:
Yes, I also feel there are too many parameters being passed this way. In the future we could gather all these temporary outputs into a struct for unified management, so that we only need to pass that one struct each time. (A hypothetical sketch of such a buffer container follows this file's diff below.)

@Courtesy-Xs (Contributor), Apr 9, 2024:
Just some advice :). First, it is not a good idea to design an activation API that takes an output_tensor argument; if you really want this, it would be better to make it an in-place API. Second, I don't think it is a good idea to take over this kind of memory management from torch before you really understand it or have designed a proper memory management system; the performance gain looks small and may just be normal fluctuation, so this optimization point may not pay off. Finally, it is probably not worth writing tricky code for such a small performance gain.

@yuehuayingxueluo (Contributor Author), Apr 9, 2024:
During testing we do see a stable performance benefit, and compared with the other optimizations the gain is already quite considerable. Also, this is not about managing memory on torch's behalf; it is about fixing our own unreasonable use of memory. That said, I agree this operator should be implemented as an in-place operator, which would avoid the redundant allocation entirely.

@yuehuayingxueluo (Contributor Author), Apr 9, 2024:
I feel this is only a temporary optimization; the optimal solution is to implement the operator as an in-place one. We can put a TODO here.

sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
@@ -216,7 +221,7 @@ def llama_decoder_layer_forward(

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, silu_and_mul_output)

return hidden_states, residual

@@ -481,12 +486,12 @@ def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:

return mlp_layer

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, silu_and_mul_output: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
inference_ops.silu_and_mul(gate_up_proj_out, silu_and_mul_output)
return torch.mm(silu_and_mul_output, self.down_proj_weight)
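The review thread above suggests gathering the per-forward scratch tensors (output_tensor, norm_output, silu_and_mul_output) into one object so they no longer have to be threaded through every layer signature. A hypothetical sketch of that idea; the ForwardBuffers name and its allocate helper are invented for illustration and are not part of this PR:

```python
from dataclasses import dataclass
import torch

@dataclass
class ForwardBuffers:
    """Hypothetical container for preallocated scratch tensors reused across decoder layers."""
    output_tensor: torch.Tensor        # attention output, [num_tokens, num_heads * head_dim]
    norm_output: torch.Tensor          # RMSNorm scratch, same shape as hidden_states
    silu_and_mul_output: torch.Tensor  # MLP activation scratch, [num_tokens, intermediate_size]

    @classmethod
    def allocate(cls, num_tokens: int, hidden_size: int, num_heads: int, head_dim: int,
                 intermediate_size: int, dtype: torch.dtype, device: torch.device) -> "ForwardBuffers":
        return cls(
            output_tensor=torch.empty(num_tokens, num_heads * head_dim, dtype=dtype, device=device),
            norm_output=torch.empty(num_tokens, hidden_size, dtype=dtype, device=device),
            silu_and_mul_output=torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device),
        )

# A decoder layer would then accept a single `buffers: ForwardBuffers` argument instead of
# three separate tensor parameters, e.g.:
#   hidden_states = self.mlp(hidden_states, buffers.silu_and_mul_output)
```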
1 change: 1 addition & 0 deletions examples/inference/benchmark_llama.py
@@ -117,6 +117,7 @@ def benchmark_inference(args):
max_output_len=args.output_len,
prefill_ratio=1.2,
block_size=32,
use_cuda_kernel=True,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
elif args.mode == "vllm":
13 changes: 1 addition & 12 deletions extensions/csrc/cuda/activation_kernel.cu
@@ -34,18 +34,8 @@ __global__ void act_and_mul_kernel(

// Note(LiuYang):This func is designed for calculation mode like
// silu(x[:half_1stdim]) * (x[half_1stdim:])
torch::Tensor silu_and_mul(const torch::Tensor& ins)
void silu_and_mul(const torch::Tensor& ins, torch::Tensor& outs)
Contributor:
This doesn't handle the case where outs is None.

Contributor Author:
If a None value is passed in, it is an illegal operation and the C++ side will report an error.

Contributor:
I mean the case should be considered, whether or not you dispatch to a different kernel. The modification here loses the ability to handle the regular way of calling the kernel (inputs only). (A sketch of an allocate-on-demand wrapper follows this file's diff below.)

@yuehuayingxueluo (Contributor Author), Apr 9, 2024:
Okay, let me think about how to fix it.

{
// Note(LiuYang): According to torch doc, vec() may cost a lot, but I did't find a better api
// to manipulate ins_shape which is IntArrayRef
auto ins_shape = ins.sizes().vec();

ins_shape[0] = ins_shape[0]/2;
if (ins_shape[0] == 1) {
ins_shape.erase(ins_shape.begin());
}
auto outs = torch::zeros(ins_shape,ins.options());

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// Note(Liuyang): numel of ins must be divisible by 2
@@ -71,5 +61,4 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
);)

AT_CUDA_CHECK(cudaGetLastError());
return outs;
}
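The reviewers' exchange above asks that the out-parameter rewrite keep the old "inputs only" calling convention. One way to do that without touching the kernel would be a thin Python-side wrapper that allocates the output when the caller does not supply one. This is a hedged sketch under that assumption, not code from this PR; the commented call refers to the new out-variant binding, and the pure-PyTorch fallback exists only so the sketch runs standalone:

```python
from typing import Optional
import torch
import torch.nn.functional as F

def silu_and_mul(ins: torch.Tensor, outs: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Sketch of an allocate-on-demand wrapper around an out-variant silu_and_mul kernel.

    Shape handling mirrors the removed C++ body: the first dimension (stacked gate/up
    projections) is halved, and squeezed away when it becomes 1.
    """
    if outs is None:
        out_shape = list(ins.shape)
        out_shape[0] //= 2
        if out_shape[0] == 1:
            out_shape = out_shape[1:]
        outs = torch.empty(out_shape, dtype=ins.dtype, device=ins.device)
    # Real code would now call the extension's out-variant, e.g.
    #     inference_ops.silu_and_mul(ins, outs)
    # Pure-PyTorch reference of the same computation, so this sketch runs on its own:
    half = ins.size(0) // 2
    outs.copy_((F.silu(ins[:half]) * ins[half:]).view_as(outs))
    return outs

# Old calling convention (inputs only) keeps working:
x = torch.randn(2, 8, 16)
y = silu_and_mul(x)          # output allocated internally, shape [8, 16]
# New convention with a caller-provided buffer:
buf = torch.empty(8, 16)
silu_and_mul(x, buf)
assert torch.allclose(y, buf)
```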
2 changes: 1 addition & 1 deletion extensions/csrc/cuda/pybind/inference.cpp
@@ -39,7 +39,7 @@ void rotary_embedding_and_cache_copy(
torch::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision);

torch::Tensor silu_and_mul(const torch::Tensor& ins);
void silu_and_mul(const torch::Tensor& ins, torch::Tensor& outs);

void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
4 changes: 3 additions & 1 deletion tests/test_infer/test_inference_engine.py
@@ -40,7 +40,9 @@ def check_inference_engine(use_engine=False, prompt_template=None):
top_k = 50

if use_engine:
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
inference_config = InferenceConfig(
max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=True, dtype="fp32"
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
3 changes: 2 additions & 1 deletion tests/test_infer/test_ops/cuda/test_silu_and_mul.py
@@ -20,7 +20,8 @@ def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
ref_out = act_out * ref_input[1]

origin_out = inference_ops.silu_and_mul(origin_input)
origin_out = torch.empty_like(ref_out)
inference_ops.silu_and_mul(origin_input, origin_out)
Contributor:
Same as above: there is no test for None as the output tensor.

Contributor Author:
Same as above.

Contributor:
See the reply above.


if dtype == torch.float32:
assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)