Skip to content

Commit

Permalink
enable generate.fsm with llamacpp by using outlines.processors
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Jun 30, 2024
1 parent 128f7e6 commit 5c7c546
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 18 deletions.
19 changes: 18 additions & 1 deletion outlines/generate/fsm.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,31 @@
from functools import singledispatch

import interegular

from outlines.fsm.guide import RegexGuide
from outlines.generate.api import SequenceGenerator
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, LlamaCpp, Transformers
from outlines.samplers import Sampler, multinomial


@singledispatch
def fsm(
model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
) -> SequenceGenerator:
fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)
return generator


@fsm.register(MLXLM)
@fsm.register(Transformers)
@fsm.register(LlamaCpp)
def fsm_unified(
model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
from outlines.processors import FSMLogitsProcessor

fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm)
return SequenceGeneratorAdapter(model, logits_processor, sampler)
13 changes: 1 addition & 12 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):

@regex.register(MLXLM)
@regex.register(Transformers)
@regex.register(LlamaCpp)
def regex_unified(
model,
regex_str: str,
Expand All @@ -52,18 +53,6 @@ def regex_unified(
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(LlamaCpp)
def regex_llamacpp(
model: LlamaCpp,
regex_str: str,
sampler: Sampler = multinomial(),
):
from outlines.integrations.llamacpp import RegexLogitsProcessor

logits_processor = RegexLogitsProcessor(regex_str, llm=model.model)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(VLLM)
def regex_vllm(
model: VLLM,
Expand Down
6 changes: 1 addition & 5 deletions outlines/generate/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:

@text.register(MLXLM)
@text.register(Transformers)
@text.register(LlamaCpp)
def text_unified(model, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)

Expand All @@ -47,11 +48,6 @@ def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)


@text.register(LlamaCpp)
def text_llamacpp(model: LlamaCpp, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)


@text.register(OpenAI)
def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI:
if not isinstance(sampler, multinomial):
Expand Down
4 changes: 4 additions & 0 deletions outlines/models/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ class LlamaCpp:
def __init__(self, model: "Llama"):
self.model = model

@property
def tokenizer(self):
return LlamaCppTokenizer(self.model)

def prepare_generation_parameters(
self,
generation_parameters: GenerationParameters,
Expand Down
11 changes: 11 additions & 0 deletions tests/generate/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ def test_generate_text_stream(request, model_fixture):
assert isinstance(token, str)


@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_fsm(request, model_fixture, pattern):
import interegular

model = request.getfixturevalue(model_fixture)
generator = generate.fsm(model, interegular.parse_pattern(pattern).to_fsm())
res = generator("test")
assert re.fullmatch(pattern, res) is not None, res


@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_regex(request, model_fixture, pattern):
Expand Down

0 comments on commit 5c7c546

Please sign in to comment.