
[Inference] Clean duplicated vector utils #5715

Open · wants to merge 182 commits into base: main

Commits (182)
4cf4682
[Inference] First PR for rebuild colossal-infer (#5143)
CjhHa1 Dec 1, 2023
56e75ee
[Inference] Add readme (roadmap) and fulfill request handler (#5147)
CjhHa1 Dec 1, 2023
2bb9224
[Inference/NFC] Clean outdated inference tests and deprecated kernels…
yuanheng-zhao Dec 5, 2023
fab9b93
[Inference]Add BatchInferState, Sequence and InferConfig (#5149)
yuehuayingxueluo Dec 7, 2023
3de2e62
[Inference] Add CacheBlock and KV-Cache Manager (#5156)
yuanheng-zhao Dec 11, 2023
93aeacc
[Inference]Update inference config and fix test (#5178)
CjhHa1 Dec 12, 2023
8daee26
[Inference] Add the logic of the inference engine (#5173)
yuehuayingxueluo Dec 18, 2023
0e61646
[Inference] add logit processor and request handler (#5166)
CjhHa1 Dec 25, 2023
86853a3
Add padding llama model
yuehuayingxueluo Dec 25, 2023
62fd08e
Fixed a bug in the inference frame
yuehuayingxueluo Dec 26, 2023
6296858
fix bugs in request_handler
yuehuayingxueluo Jan 2, 2024
9489dc6
precision alignment
yuehuayingxueluo Jan 2, 2024
4df8876
Fixed a writing error
yuehuayingxueluo Jan 2, 2024
07b5283
[kernel] Add triton kernel for context attention (FAv2) without paddi…
yuanheng-zhao Jan 3, 2024
02c1bf8
add context_attention_unpadded
yuehuayingxueluo Jan 3, 2024
bbfebfb
fix bugs in sampler
yuehuayingxueluo Jan 4, 2024
b2eb9cd
Fixed a typo
yuehuayingxueluo Jan 4, 2024
3ad1f3b
fix beam_width
yuehuayingxueluo Jan 4, 2024
bfd9b1b
[Inference] Pytorch Attention func, pad&nopad input support (#5219)
CjhHa1 Jan 4, 2024
47e53ea
fix bugs in attention.py and request_handler.py
yuehuayingxueluo Jan 8, 2024
fa4fbdb
adapted to pad_context_forward
yuehuayingxueluo Jan 9, 2024
e545a87
[Hotfix] Fix accuracy and align attention method api with Triton kern…
CjhHa1 Jan 8, 2024
2a73e82
fix bugs related to processing padding mask
yuehuayingxueluo Jan 9, 2024
fab294c
fix CI bugs
yuehuayingxueluo Jan 9, 2024
10e3c9f
rm torch.cuda.synchronize
yuehuayingxueluo Jan 9, 2024
d40eb26
fix bugs in request_handler.py and engine.py
yuehuayingxueluo Jan 10, 2024
fded91d
[Inference] Kernel: no pad rotary embedding (#5252)
CjhHa1 Jan 11, 2024
1513f20
[kernel] Add flash decoding triton kernel for blocked kv cache (#5249)
yuanheng-zhao Jan 11, 2024
1ded7e8
[git] fixed rebased files
FrankLeeeee Jan 11, 2024
fa85e02
[kernel] Add KV cache copy kernel during decoding (#5261)
yuanheng-zhao Jan 15, 2024
c597678
[doc] updated inference readme (#5269)
FrankLeeeee Jan 15, 2024
d8db500
[Inference] Fix request handler and add recycle logic (#5260)
CjhHa1 Jan 15, 2024
0f2b46a
[kernel] Revise KVCache copy triton kernel API (#5273)
yuanheng-zhao Jan 16, 2024
86b63f7
[Inference]Adapted to the triton attn kernels (#5264)
yuehuayingxueluo Jan 17, 2024
5ae9099
[kernel] Add RMSLayerNorm triton kernel (#5262)
nkfyz Jan 18, 2024
9e2342b
[Hotfix] Fix bugs in testing continuous batching (#5270)
CjhHa1 Jan 18, 2024
6e487e7
[kernel/fix] Performance Optimization for Decoding Kernel and Benchma…
yuanheng-zhao Jan 19, 2024
bfff925
[inference] Adapted to Rotary Embedding and RMS Norm (#5283)
yuehuayingxueluo Jan 22, 2024
cea9c86
add utils.py
yuehuayingxueluo Jan 22, 2024
b785319
Merge pull request #5297 from yuehuayingxueluo/fix_rotary_embedding
yuehuayingxueluo Jan 22, 2024
8e606ec
[Inference] Benchmarking rotary embedding and add a fetch function (#…
CjhHa1 Jan 23, 2024
3da9993
[Kernel/Fix] Revise flash attention triton kernel API and add benchma…
yuanheng-zhao Jan 23, 2024
c647e00
[Inference]Add fused rotary kernel and get cos cache kernel (#5302)
CjhHa1 Jan 24, 2024
af8359c
[hotfix] fix boundary check in batch (#5306)
yuanheng-zhao Jan 25, 2024
4f28cb4
[inference]Optimize the usage of the mid tensors space in flash attn …
yuehuayingxueluo Jan 26, 2024
7ddd8b3
fix (#5311)
CjhHa1 Jan 26, 2024
1f8a75d
[Inference] Update rms norm kernel, benchmark with vLLM (#5315)
CjhHa1 Jan 29, 2024
c7c104c
[DOC] Update inference readme (#5280)
CjhHa1 Jan 29, 2024
e8f0642
[Inference]Add Nopadding Llama Modeling (#5327)
yuehuayingxueluo Jan 30, 2024
5f98a9d
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
yuanheng-zhao Jan 30, 2024
c565519
merge commit
FrankLeeeee Jan 31, 2024
1336838
Merge pull request #5339 from FrankLeeeee/sync/merge-main
FrankLeeeee Jan 31, 2024
df0aa49
[Inference] Kernel Fusion, fused copy kv cache into rotary embedding …
CjhHa1 Jan 31, 2024
f8e456d
[inference] simplified config verification (#5346)
FrankLeeeee Feb 1, 2024
249644c
[Inference]Repalce Attention layer and MLP layer by shardformer to op…
yuehuayingxueluo Feb 1, 2024
db1a763
[inference] removed redundancy init_batch (#5353)
FrankLeeeee Feb 2, 2024
e76acbb
[inference] moved ops tests to test_infer (#5354)
FrankLeeeee Feb 2, 2024
027aa10
[doc] updated inference readme (#5343)
FrankLeeeee Feb 2, 2024
21ad4a2
[Inference/opt]Optimize the mid tensor of RMS Norm (#5350)
yuehuayingxueluo Feb 2, 2024
631862f
[Inference]Optimize generation process of inference engine (#5356)
yuehuayingxueluo Feb 2, 2024
1dedb57
[Fix/Infer] Remove unused deps and revise requirements (#5341)
yuanheng-zhao Feb 6, 2024
35382a7
[Inference]Fused the gate and up proj in mlp,and optimized the autogr…
yuehuayingxueluo Feb 6, 2024
9f4ab2e
[Inference] Adapt to Fused rotary (#5348)
CjhHa1 Feb 7, 2024
8106ede
Revert "[Inference] Adapt to Fused rotary (#5348)" (#5373)
FrankLeeeee Feb 7, 2024
58740b5
[inference] added inference template (#5375)
FrankLeeeee Feb 7, 2024
6fb4bcb
[Inference/opt] Fused KVCahce Memcopy (#5374)
yuehuayingxueluo Feb 7, 2024
1f8c7e7
[Inference] User Experience: update the logic of default tokenizer an…
CjhHa1 Feb 7, 2024
9afa520
[inference] refactored config (#5376)
FrankLeeeee Feb 8, 2024
8c69deb
[Inference]Support vllm testing in benchmark scripts (#5379)
yuehuayingxueluo Feb 8, 2024
b21aac5
[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
yuanheng-zhao Feb 19, 2024
7301038
[Inference]Fused kv copy into rotary calculation (#5383)
CjhHa1 Feb 21, 2024
2a718c8
Optimized the execution interval time between cuda kernels caused by …
yuehuayingxueluo Feb 21, 2024
bc1da87
[Fix/Inference] Fix format of input prompts and input model in infer…
yuehuayingxueluo Feb 23, 2024
1906118
[Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)
yuanheng-zhao Feb 26, 2024
600881a
[Inference]Add CUDA KVCache Kernel (#5406)
yuehuayingxueluo Feb 28, 2024
0aa27f1
[Inference]Move benchmark-related code to the example directory. (#5408)
yuehuayingxueluo Feb 28, 2024
0310b76
Merge branch 'main' into sync/main
FrankLeeeee Mar 4, 2024
593a72e
Merge pull request #5424 from FrankLeeeee/sync/main
FrankLeeeee Mar 4, 2024
95c2149
add silu_and_mul for infer
Courtesy-Xs Mar 7, 2024
cefaeb5
[feat] cuda graph support and refactor non-functional api
LRY89757 Mar 8, 2024
2b28b54
Merge pull request #5433 from Courtesy-Xs/add_silu_and_mul
Courtesy-Xs Mar 8, 2024
a46598a
add reusable utils for cuda
Courtesy-Xs Mar 8, 2024
01d289d
Merge branch 'feature/colossal-infer' of https://github.com/hpcaitech…
Courtesy-Xs Mar 8, 2024
5eb5ff1
refactor code
Courtesy-Xs Mar 8, 2024
f7aecc0
feat rmsnorm cuda kernel and add unittest, benchmark script (#5417)
SunflowerAries Mar 8, 2024
b2c0d9f
[fix] multi graphs capture error
LRY89757 Mar 11, 2024
9dec66f
[fix] multi graphs capture error
LRY89757 Mar 11, 2024
633e95b
[doc] add doc
LRY89757 Mar 11, 2024
21e1e36
Merge pull request #5435 from Courtesy-Xs/add_gpu_launch_config
Courtesy-Xs Mar 11, 2024
095c070
refactor code
Courtesy-Xs Mar 11, 2024
368a2aa
Merge pull request #5445 from Courtesy-Xs/refactor_infer_compilation
Courtesy-Xs Mar 12, 2024
b699f54
optimize rmsnorm: add vectorized elementwise op, feat loop unrolling …
SunflowerAries Mar 12, 2024
c1c45e9
fix include path
Courtesy-Xs Mar 13, 2024
6fd355a
Merge pull request #5452 from Courtesy-Xs/fix_include_path
Courtesy-Xs Mar 13, 2024
ed431de
fix rmsnorm template function invocation problem(template function pa…
SunflowerAries Mar 13, 2024
f366a5e
[Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA…
yuehuayingxueluo Mar 13, 2024
1821a6d
[fix] pytest and fix dyn grid bug
LRY89757 Mar 13, 2024
ae24b4f
diverse tests
LRY89757 Mar 14, 2024
d02e257
Merge branch 'feature/colossal-infer' into colossal-infer-cuda-graph
LRY89757 Mar 14, 2024
388e043
add implementatino for GetGPULaunchConfig1D
Courtesy-Xs Mar 14, 2024
6e30248
[fix] tmp for test
LRY89757 Mar 14, 2024
5724b9e
add some comments
Courtesy-Xs Mar 15, 2024
b6e9785
Merge pull request #5457 from Courtesy-Xs/ly_add_implementation_for_l…
Courtesy-Xs Mar 15, 2024
48c4f29
refactor vector utils
Courtesy-Xs Mar 19, 2024
aabc9fb
[feat] add use_cuda_kernel option
LRY89757 Mar 19, 2024
b96557b
Merge pull request #5469 from Courtesy-Xs/add_vec_traits
Courtesy-Xs Mar 19, 2024
7ff42cc
add vec_type_trait implementation (#5473)
Courtesy-Xs Mar 19, 2024
4eafe0c
[fix] unused option
LRY89757 Mar 21, 2024
606603b
Merge branch 'feature/colossal-infer' of https://github.com/hpcaitech…
LRY89757 Mar 21, 2024
5b017d6
[fix]
LRY89757 Mar 21, 2024
9fe61b4
[fix]
LRY89757 Mar 25, 2024
ff4998c
[fix] remove unused comment
LRY89757 Mar 25, 2024
87079cf
[Inference]Support FP16/BF16 Flash Attention 2 And Add high_precision…
yuehuayingxueluo Mar 25, 2024
68e9396
[fix] merge conflicts
LRY89757 Mar 25, 2024
1d62623
Merge pull request #5434 from LRY89757/colossal-infer-cuda-graph
LRY89757 Mar 25, 2024
6251d68
[fix] PR #5354 (#5501)
LRY89757 Mar 25, 2024
e6496dd
[Inference] Optimize request handler of llama (#5512)
Courtesy-Xs Mar 26, 2024
934e31a
The writing style of tail processing and the logic related to macro d…
yuehuayingxueluo Mar 28, 2024
04aca9e
[Inference/Kernel]Add get_cos_and_sin Kernel (#5528)
yuehuayingxueluo Apr 1, 2024
a2878e3
[Inference] Add Reduce Utils (#5537)
Courtesy-Xs Apr 1, 2024
4bb5d89
[Fix/Inference] Remove unused and non-functional functions (#5543)
yuanheng-zhao Apr 2, 2024
7ebdf48
add cast and op_functor for cuda build-in types (#5546)
Courtesy-Xs Apr 8, 2024
ed5ebd1
[Fix] resolve conflicts of merging main
yuanheng-zhao Apr 8, 2024
ce9401a
remove unused triton kernels
yuanheng-zhao Apr 8, 2024
d788175
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
7ca1d1c
remove outdated triton test
yuanheng-zhao Apr 8, 2024
d56c963
Sync main to feature/colossal-infer
yuanheng-zhao Apr 9, 2024
d63c469
[Infer] Revise and Adapt Triton Kernels for Spec-Dec (#5401)
yuanheng-zhao Feb 28, 2024
5a9b05f
[Inference/SpecDec] Add Basic Drafter Model Container (#5405)
yuanheng-zhao Feb 28, 2024
a37f826
[Inference/SpecDec] Add Speculative Decoding Implementation (#5423)
yuanheng-zhao Mar 11, 2024
912e24b
[SpecDec] Fix inputs for speculation and revise past KV trimming (#5449)
yuanheng-zhao Mar 12, 2024
d85d914
[Inference/SpecDec] Support GLIDE Drafter Model (#5455)
yuanheng-zhao Apr 1, 2024
e1acb58
[doc] Add inference/speculative-decoding README (#5552)
yuanheng-zhao Apr 3, 2024
e60d430
[Fix] resolve conflicts of rebasing feat/speculative-decoding (#5557)
yuanheng-zhao Apr 7, 2024
f8598e3
[Fix] Llama Modeling Control with Spec-Dec (#5580)
yuanheng-zhao Apr 10, 2024
25928d8
[Inference/Spec-Dec] Merge pull request #5565 from hpcaitech/feat/spe…
yuanheng-zhao Apr 10, 2024
a219123
refactor csrc (#5582)
Courtesy-Xs Apr 11, 2024
d4cb023
[Inference/Refactor] Delete Duplicated code and refactor vec_copy uti…
Courtesy-Xs Apr 15, 2024
56b222e
[inference/model]Adapted to the baichuan2-7B model (#5591)
yuehuayingxueluo Apr 15, 2024
be396ad
[Inference/Kernel] Add Paged Decoding kernel, sequence split within t…
SunflowerAries Apr 18, 2024
e37ee2f
[Feat]Tensor Model Parallel Support For Inference (#5563)
LRY89757 Apr 18, 2024
ccf7279
feat baichuan2 rmsnorm whose hidden size equals to 5120 (#5611)
SunflowerAries Apr 19, 2024
5d4c1fe
[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)
yuanheng-zhao Apr 23, 2024
12f10d5
[Fix/Inference]Fix CUDA Rotary Rmbedding GQA (#5623)
yuehuayingxueluo Apr 23, 2024
04863a9
[example] Update Llama Inference example (#5629)
yuanheng-zhao Apr 23, 2024
279300d
[Inference/Refactor] Refactor compilation mechanism and unified multi…
Courtesy-Xs Apr 24, 2024
90cd522
[Fix/Inference]Fix vllm benchmark (#5630)
yuehuayingxueluo Apr 24, 2024
a8fd3b0
[Inference/Kernel] Optimize paged attention: Refactor key cache layou…
SunflowerAries Apr 25, 2024
f342a93
[Fix] Remove obsolete files - inference (#5650)
yuanheng-zhao Apr 25, 2024
3c91e3f
[Inference]Adapt to baichuan2 13B (#5614)
yuehuayingxueluo Apr 25, 2024
5be590b
[kernel] Support new KCache Layout - Context Attention Triton Kernel …
yuanheng-zhao Apr 26, 2024
8ccb671
[Inference/Feat] Add kvcache quantization support for FlashDecoding (…
Courtesy-Xs Apr 26, 2024
808ee6e
[Inference/Feat] Feat quant kvcache step2 (#5674)
Courtesy-Xs Apr 30, 2024
5f00002
[Inference] Adapt Baichuan2-13B TP (#5659)
yuehuayingxueluo Apr 30, 2024
5cd75ce
[Inference/Kernel] refactor kvcache manager and rotary_embedding and …
SunflowerAries Apr 30, 2024
ef8e4ff
[Inference/Feat] Add kvcache quant support for fused_rotary_embedding…
Courtesy-Xs Apr 30, 2024
f799631
[inference]Add alibi to flash attn function (#5678)
yuehuayingxueluo Apr 30, 2024
9df016f
[Inference] Fix quant bits order (#5681)
Courtesy-Xs Apr 30, 2024
537a3cb
[kernel] Support New KCache Layout - Triton Kernel (#5677)
yuanheng-zhao May 3, 2024
56ed09a
[sync] resolve conflicts of merging main
yuanheng-zhao May 5, 2024
8754aba
[Fix] Fix & Update Inference Tests (compatibility w/ main)
yuanheng-zhao May 5, 2024
725fbd2
[Inference] Remove unnecessary float4_ and rename float8_ to float8 (…
SunflowerAries May 6, 2024
db7b305
[Sync] Update from main to feature/colossal-infer (Merge pull request…
yuanheng-zhao May 6, 2024
1ace106
[Inference/Feat] Add quant kvcache support for decode_kv_cache_memcpy…
Courtesy-Xs May 6, 2024
f9afe0a
[hotfix] Fix KV Heads Number Assignment in KVCacheManager (#5695)
yuanheng-zhao May 7, 2024
55cc7f3
[Fix] Fix Inference Example, Tests, and Requirements (#5688)
yuanheng-zhao May 8, 2024
12e7c28
[hotfix] fix OpenMOE example import path (#5697)
yuanheng-zhao May 8, 2024
9c2fe79
[Inference]Adapt temperature processing logic (#5689)
yuehuayingxueluo May 8, 2024
d482922
[Inference] Support the logic related to ignoring EOS token (#5693)
yuehuayingxueluo May 8, 2024
69cd7e0
[Inference] ADD async and sync Api server using FastAPI (#5396)
CjhHa1 Mar 1, 2024
de378cd
[Inference] Finish Online Serving Test, add streaming output api, con…
CjhHa1 Mar 18, 2024
c064032
[Online Server] Chat Api for streaming and not streaming response (#5…
CjhHa1 Apr 7, 2024
7bbb28e
[Inference] resolve rebase conflicts
CjhHa1 Apr 11, 2024
61a1b2e
[Inference] Fix bugs and docs for feat/online-server (#5598)
CjhHa1 May 8, 2024
bc9063a
resolve rebase conflicts on Branch feat/online-serving
CjhHa1 May 8, 2024
5d9a494
[Inference] Add example test_ci script
CjhHa1 May 9, 2024
492520d
Merge pull request #5588 from hpcaitech/feat/online-serving
CjhHa1 May 9, 2024
bfad393
[Inference/Feat] Add quant kvcache interface (#5700)
Courtesy-Xs May 9, 2024
50104ab
[Inference/Feat] Add convert_fp8 op for fp8 test in the future (#5706)
Courtesy-Xs May 10, 2024
de4bf3d
[Inference]Adapt repetition_penalty and no_repeat_ngram_size (#5708)
yuehuayingxueluo May 11, 2024
18d67d0
[Feat]Inference RPC Server Support (#5705)
LRY89757 May 14, 2024
30ea54f
delete copy_vector
Courtesy-Xs May 14, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -91,7 +91,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
       options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
-    timeout-minutes: 60
+    timeout-minutes: 75
     defaults:
       run:
         shell: bash
347 changes: 156 additions & 191 deletions colossalai/inference/README.md

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions colossalai/inference/__init__.py
@@ -1,4 +1,4 @@
-from .engine import InferenceEngine
-from .engine.policies import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy
+from .config import InferenceConfig
+from .core import InferenceEngine

-__all__ = ["InferenceEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
+__all__ = ["InferenceConfig", "InferenceEngine"]
523 changes: 523 additions & 0 deletions colossalai/inference/batch_bucket.py

Large diffs are not rendered by default.

341 changes: 341 additions & 0 deletions colossalai/inference/config.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions colossalai/inference/core/__init__.py
@@ -0,0 +1,4 @@
from .engine import InferenceEngine
from .request_handler import RequestHandler

__all__ = ["InferenceEngine", "RequestHandler"]
309 changes: 309 additions & 0 deletions colossalai/inference/core/async_engine.py
@@ -0,0 +1,309 @@
import asyncio
import logging
from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

from colossalai.inference.core.engine import InferenceEngine

# CLI logger
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("colossalai-inference")


def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
msg = "Task finished unexpectedly. This should never happen! "
try:
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc
raise RuntimeError(msg)
except Exception as exc:
request_tracker.propagate_exception(exc)
raise exc


class RequstStream:
    """
    A stream of outputs for a request that can be iterated over asynchronously.
    Attributes: 1. request_id: The id of the request.
                2. _future: A future that will be set when the request is finished.
    Methods: set_result and get_result; the result is set once, when the request
    finishes, and `self._future` is marked done.
    """

def __init__(self, request_id: int) -> None:
self.request_id = request_id
self._future = asyncio.Future()

def set_result(self, result) -> None:
"""Set final result and signal taht it's ready"""
if not self._future.done():
self._future.set_result(result)

async def get_result(self):
"""Wait for the result to be set and return it."""
return await self._future

@property
def finished(self) -> bool:
"""Check if the stream has finished by checking if the future is done."""
return self._future.done()


class Tracer:
"""
Recording new requests and finished requests.
Attributes: 1._request_streams: We create one stream for each request to trace the output.
2._finished_requests: A queue to store the finished requests.
3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
4.new_requests_event: An event to notify the engine that there are new requests.
"""

def __init__(self) -> None:
self._request_streams: Dict[int, RequstStream] = {}
self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue()
self.new_requests_event = None

def __contains__(self, item):
return item in self._request_streams

def init_event(self):
self.new_requests_event = asyncio.Event()

def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None:
"""
Propagate an exception to request streams (all if request_id is None).
"""
if request_id is not None:
self._request_streams[request_id].set_result(exc)
else:
for stream in self._request_streams.values():
stream.set_result(exc)

def process_finished_request(self, finished_request) -> None:
"""Process a finished request from the engine."""
request_id = finished_request.request_id
try:
self._request_streams[request_id].set_result(finished_request)
        except KeyError:
            raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check")
self.abort_request(request_id)

def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
"""
Add a request to be sent to the engine on the next background
loop iteration.
"""
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")

stream = RequstStream(request_id)
logger.info(f"Added request {request_id}.")
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
self.new_requests_event.set()

return stream

def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info(f"Aborted request {request_id}.")

self._finished_requests.put_nowait(request_id)

if request_id not in self._request_streams or self._request_streams[request_id].finished:
# The request has already finished or been aborted.
            # Requests still in new_requests will be aborted when they are fetched (if marked as aborted).
return

self._request_streams[request_id].set_result(None)

def get_new_requests(self):
"""
Get new requests from http server.
"""
new_requests: List[Dict] = []
finished_requests: Set[int] = set()

while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)

while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if new_request["request_id"] in finished_requests:
# The request has been aborted.
stream.set_result(None)
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)

self.new_requests_event.clear()

return new_requests

async def wait_for_new_requests(self):
await self.new_requests_event.wait()


class _AsyncInferenceEngine(InferenceEngine):
"""
    Async methods for the Inference Engine. This engine is an extension of InferenceEngine, and the additional methods are only used by the asynchronous wrapper.
Methods: 1. async_step: The async version of Engine.step()
"""

    async def async_step(self) -> Tuple[List[str], bool]:
"""
The async version of Engine.step()
Performs one decoding iteration and returns newly generated results.

It first schedules the sequences to be executed in the next iteration.
Then, it executes the model and updates the scheduler with the model
outputs. Finally, it decodes the sequences and returns the newly
generated results.
"""
batch = self.request_handler.schedule()
loop = asyncio.get_running_loop()

        # Use run_in_executor to asynchronously run the synchronous model forward pass.
logits = await loop.run_in_executor(
None,
self.model,
batch,
self.k_cache,
self.v_cache,
)

if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)

finished_sequences = self.request_handler.update()
for sequence in finished_sequences:
sequence.output = self.tokenizer.decode(sequence.output_token_id)

return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0


class AsyncInferenceEngine:
"""An asynchronous wrapper for the InferenceEngine class.

    This class wraps the InferenceEngine class to make it asynchronous. It uses
    asyncio to create a background loop that keeps processing incoming requests.
    Note that this class does not hold the model directly; when a new request
    arrives, `add_request` is called first and the Tracer records the request,
    handing it to the background `InferenceEngine` (in the background loop) for
    processing. You can think of this engine as an interface.
"""

_engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine

def __init__(self, start_engine_loop: bool = True, **kwargs):
self.engine = self._init_engine(**kwargs)
self.background_loop = None
# reference to the unshielded loop
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracer = Tracer()

@property
def background_loop_status(self):
return self.background_loop is not None and not self.background_loop.done()

def start_background_loop(self):
if self.background_loop_status:
raise RuntimeError("Existing loop is running")

self._request_tracer.init_event()

self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish, request_tracker=self._request_tracer)
)
self.background_loop = asyncio.shield(self._background_loop_unshielded)

def _init_engine(self, **kwargs):
return self._engine_class(**kwargs)

async def step(self):
"""
Run engine to process requests

Returns True if there are in-progress requests.
"""
new_requests = self._request_tracer.get_new_requests()
for new_request in new_requests:
self.engine.add_single_request(**new_request)
newly_finished_seqs, has_running_requests = await self.engine.async_step()

for seq in newly_finished_seqs:
self._request_tracer.process_finished_request(seq)

return has_running_requests

async def _engine_abort(self, request_ids: Iterable[int]):
self.engine.abort_request(request_ids)

async def abort(self, request_id: int):
"""
Abort a single request
"""
if not self.background_loop_status:
raise RuntimeError("Background loop is not running or launched correctly.")
return self._abort(request_id)

def _abort(self, request_id: int):
self._request_tracer.abort_request(request_id)

async def run_engine_loop(self):
processing_requests = False
while True:
if not processing_requests:
await self._request_tracer.wait_for_new_requests()
processing_requests = await self.step()
await asyncio.sleep(0)

async def add_request(
self,
request_id: int,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
) -> RequstStream:
"""
        Add a request to the background tracker (waiting queue) and start the background loop if needed.
"""
if not self.background_loop_status:
if self.start_engine_loop:
self.start_background_loop()
else:
raise RuntimeError("Background loop is not running.")
stream = self._request_tracer.add_request(
request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
)
return stream

async def generate(
self,
request_id: int,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
) -> AsyncIterator[str]:
"""
        Generate output from a request. It receives the request from the http server, adds it to the
        waiting queue of the Async Engine and streams the output sequence.
"""
try:
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
return await stream.get_result()

except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the request.
self._abort(request_id)
raise e
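
For orientation, a hypothetical driver for the asynchronous wrapper above might look like the following sketch. It only uses methods defined in this file (`add_request`, `generate`, and the background loop); the `model`, `tokenizer`, and `inference_config` keyword arguments are assumptions about what `_init_engine` forwards to `InferenceEngine`, and those objects are assumed to be built as in the earlier sketch.

import asyncio

from colossalai.inference.core.async_engine import AsyncInferenceEngine


async def main():
    # Hypothetical construction: extra kwargs are assumed to be forwarded
    # to InferenceEngine via _init_engine(**kwargs). model, tokenizer and
    # inference_config are assumed to be built as in the synchronous sketch above.
    engine = AsyncInferenceEngine(
        start_engine_loop=True,
        model=model,                        # assumed: a loaded causal LM
        tokenizer=tokenizer,                # assumed: the matching tokenizer
        inference_config=inference_config,  # assumed: an InferenceConfig
    )
    # generate() registers the request with the Tracer, starts the background
    # loop if needed, and resolves once the request's future is set.
    output = await engine.generate(request_id=0, prompt="Hello, Colossal-AI!")
    print(output)


asyncio.run(main())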