Skip to content

Commit

Permalink
fix bug with looped nested queries
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeurerkellner committed Mar 8, 2024
1 parent bd1e7ce commit 5c9d802
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 14 deletions.
35 changes: 24 additions & 11 deletions src/lmql/ops/inline_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
from .node import *
from .booleans import *
import inspect
from copy import deepcopy

class InlineCallOp(Node):
"""
Represents a nested query call e.g. in "[A: fct()]".
"""
def __init__(self, predecessors, lcls, glbs):
super().__init__(predecessors)
fct, args = predecessors
Expand Down Expand Up @@ -34,7 +38,7 @@ def __init__(self, predecessors, lcls, glbs):

# set positional arguments
assert fct.function_context is not None, "LMQL in-context function " + str(fct) + " has no function context."
self.query_kwargs = None
self.query_kwargs = self.derive_query_kwargs(predecessors[1])

def execute_predecessors(self, trace, context):
fct, args = self.predecessors
Expand All @@ -57,19 +61,27 @@ def forward(self, *args, **kwargs):
return None
return context.subinterpreter_results[si]

def derive_query_kwargs(self, args):
    """
    Derives the keyword arguments for the nested query call from the call operands.

    The first entry of `args` is skipped (presumably it is not a call
    argument -- TODO confirm against the caller). Of the remaining entries,
    2-tuples whose first element is a string starting with "__kw:" are
    treated as keyword arguments (the prefix is stripped to obtain the
    name); everything else is passed positionally. If `self.self_binding`
    is set, it is forwarded as the "__self__" keyword argument.

    Returns the first element of the pair returned by
    `self.query_fct.make_kwargs(...)`.
    """
    kw_prefix = "__kw:"

    def is_kwarg(a):
        # exact type checks on purpose: only plain (str, value) tuples count as keyword arguments
        return type(a) is tuple and len(a) == 2 and type(a[0]) is str and a[0].startswith(kw_prefix)

    args = args[1:]

    # get python values of variables when available
    args = [a.python_value if isinstance(a, Var) and a.python_variable else a for a in args]

    kwargs = {a[0][len(kw_prefix):]: a[1] for a in args if is_kwarg(a)}
    positional = [a for a in args if not is_kwarg(a)]

    if self.self_binding is not None:
        kwargs["__self__"] = self.self_binding
    query_kwargs, _ = self.query_fct.make_kwargs(*positional, **kwargs)

    return query_kwargs

def subinterpreter(self, runtime, prompt, args = None):
query_kwargs = None
if args is not None:
args = args[1:]
def is_kwarg(a):
return type(a) is tuple and len(a) == 2 and type(a[0]) is str and a[0].startswith("__kw:")
kwargs = {k[0][len("__kw:"):]: k[1] for k in args if is_kwarg(k)}
args = [a for a in args if not is_kwarg(a)]

if self.self_binding is not None:
kwargs["__self__"] = self.self_binding
query_kwargs, _ = self.query_fct.make_kwargs(*args, **kwargs)
self.query_kwargs = query_kwargs
self.query_kwargs = self.derive_query_kwargs(args)
else:
if self.query_kwargs is None:
return None
Expand Down Expand Up @@ -97,6 +109,7 @@ def postprocess(self, operands, value):
runtime = context.runtime
# instructs postprocessing logic to use subinterpreters results as postprocessing results
si = self.subinterpreter(runtime, context.prompt, args)
assert si.root_state is not None, "subinterpreter result is not initialized: {} (this is very likely a bug, please report it in the issue tracker of the project)".format(si)
return si

def postprocess_order(self, other, operands, other_inputs, **kwargs):
Expand Down
2 changes: 0 additions & 2 deletions src/lmql/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,8 +1258,6 @@ async def subinterpreter_results(self, s: dc.DecoderSequence, variable, text, di

for ic in inline_calls:
si = ic.subinterpreter(self, calling_state.prompt)
if si is None: continue

subinterpreters.append(si)

# prepare subinterpreter if this is the first time it is used
Expand Down
65 changes: 65 additions & 0 deletions src/lmql/tests/test_nested_back2back.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import lmql
from lmql.tests.expr_test_utils import run_all_tests

nl = "\n"

class Output:
    """
    Collects values grouped by output variable name.

    `output_vars` maps each name to the list of values recorded for it,
    in the order they were added.
    """

    def __init__(self):
        self.output_vars = {}

    def add(self, vvar, value):
        """Appends a single `value` to the list stored under `vvar`."""
        self.output_vars.setdefault(vvar, []).append(value)

    def add_all(self, vvar, values):
        """Appends every item of `values` to the list stored under `vvar`."""
        # extending a list owned by this object keeps the caller's iterable un-aliased
        self.output_vars.setdefault(vvar, []).extend(values)


# LMQL query function: the string below starts with the `lmql` marker and is
# presumably the query program itself (runtime text, not documentation).
# As listed, it prompts for "a {option_text}", generates [option] stopping at a
# newline or ".", contains a follow-up "make it shorter" prompt guarded by
# `len(option) > length_limit`, and records `option.strip()` through
# `o.add(option_type, ...)`.
# NOTE(review): indentation of this listing was lost in extraction, so which
# statements belong to the `if` body cannot be confirmed from here.
# NOTE(review): `offset` is assigned but not referenced in the visible program.
@lmql.query
async def make_summary_option(o, option_text, length_limit, option_type):
'''lmql
offset = len(context.prompt)
"{:user}Make a {option_text}"
"{:assistant}[option]" where STOPS_AT(option, nl) and STOPS_AT(option, ".")
if len(option) > length_limit:
"{:user}The {option_text} is too long, make it shorter"
"{:assistant}[option]" where STOPS_AT(option, nl) and STOPS_AT(option, ".")
o.add(option_type, option.strip())
'''

# Test query for nested query calls inside loops (see the commit message:
# "fix bug with looped nested queries"). The program creates an `Output`,
# emits a system prompt containing `material_text`, and uses two `for` loops
# that each call `make_summary_option` as a nested query, once for
# "correct answer" and once for "distractor".
# NOTE(review): indentation of this listing was lost in extraction; the two
# loops are presumably sequential ("back2back") rather than nested -- confirm
# against the repository.
# NOTE(review): requires the local llama.cpp model file named in the decorator.
@lmql.query(model=lmql.model("local:llama.cpp:/lmql/llama-2-7b-chat.Q2_K.gguf", tokenizer="AyyYOO/Luna-AI-Llama2-Uncensored-FP16-sharded"))
async def test_summary_back2back():
'''lmql
o = Output()
material_text = "This is a story about love."
n_summaries = 2
n_distractors = 2
"""{:system}The user wants to make an activity for the instructions `Select the sentences that are true in the text`.
To do it they will ask you for `correct answers` or `distractors` for a particular text.
Here is the text:
{material_text}
If you are asked for a correct answer more than once make them different.
Only answer with the answer or distractor, nothing more.
Your responses should be shorter than 100 characters.
Avoid repetitive language.
"""
length_limit = 120
for i in range(n_summaries):
'[correct_summaries: make_summary_option(o, "correct answer", length_limit, "correct_summaries")]'
for i in range(n_distractors):
'[incorrect_summaries: make_summary_option(o, "distractor", length_limit, "incorrect_summaries")]'
'''

# when executed as a script, run the tests defined in this module
if __name__ == "__main__":
    run_all_tests(globals())
2 changes: 1 addition & 1 deletion src/lmql/tests/test_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def test_conditional_stopping():
from
lmql.model("local:llama.cpp:/lmql/llama-2-7b-chat.Q2_K.gguf", tokenizer="AyyYOO/Luna-AI-Llama2-Uncensored-FP16-sharded")
where
len(TOKENS(OUTPUT)) > 80 and STOPS_AT(OUTPUT, "review")
len(TOKENS(OUTPUT)) > 25 and STOPS_AT(OUTPUT, "review")
'''

@lmql.query
Expand Down

0 comments on commit 5c9d802

Please sign in to comment.