diff --git a/src/lmql/ops/inline_call.py b/src/lmql/ops/inline_call.py index 9ad5eb5b..920d37e9 100644 --- a/src/lmql/ops/inline_call.py +++ b/src/lmql/ops/inline_call.py @@ -2,8 +2,12 @@ from .node import * from .booleans import * import inspect +from copy import deepcopy class InlineCallOp(Node): + """ + Represents a nested query call e.g. in "[A: fct()]". + """ def __init__(self, predecessors, lcls, glbs): super().__init__(predecessors) fct, args = predecessors @@ -34,7 +38,7 @@ def __init__(self, predecessors, lcls, glbs): # set positional arguments assert fct.function_context is not None, "LMQL in-context function " + str(fct) + " has no function context." - self.query_kwargs = None + self.query_kwargs = self.derive_query_kwargs(predecessors[1]) def execute_predecessors(self, trace, context): fct, args = self.predecessors @@ -57,19 +61,27 @@ def forward(self, *args, **kwargs): return None return context.subinterpreter_results[si] + def derive_query_kwargs(self, args): + args = args[1:] + + # get python values of variables when available + args = [a.python_value if isinstance(a, Var) and a.python_variable else a for a in args] + + def is_kwarg(a): + return type(a) is tuple and len(a) == 2 and type(a[0]) is str and a[0].startswith("__kw:") + kwargs = {k[0][len("__kw:"):]: k[1] for k in args if is_kwarg(k)} + args = [a for a in args if not is_kwarg(a)] + + if self.self_binding is not None: + kwargs["__self__"] = self.self_binding + query_kwargs, _ = self.query_fct.make_kwargs(*args, **kwargs) + + return query_kwargs + def subinterpreter(self, runtime, prompt, args = None): query_kwargs = None if args is not None: - args = args[1:] - def is_kwarg(a): - return type(a) is tuple and len(a) == 2 and type(a[0]) is str and a[0].startswith("__kw:") - kwargs = {k[0][len("__kw:"):]: k[1] for k in args if is_kwarg(k)} - args = [a for a in args if not is_kwarg(a)] - - if self.self_binding is not None: - kwargs["__self__"] = self.self_binding - query_kwargs, _ = self.query_fct.make_kwargs(*args, **kwargs) - self.query_kwargs = query_kwargs + self.query_kwargs = self.derive_query_kwargs(args) else: if self.query_kwargs is None: return None @@ -97,6 +109,7 @@ def postprocess(self, operands, value): runtime = context.runtime # instructs postprocessing logic to use subinterpreters results as postprocessing results si = self.subinterpreter(runtime, context.prompt, args) + assert si.root_state is not None, "subinterpreter result is not initialized: {} (this is very likely a bug, please report it in the issue tracker of the project)".format(si) return si def postprocess_order(self, other, operands, other_inputs, **kwargs): diff --git a/src/lmql/runtime/interpreter.py b/src/lmql/runtime/interpreter.py index 38a7d538..487fa1cb 100644 --- a/src/lmql/runtime/interpreter.py +++ b/src/lmql/runtime/interpreter.py @@ -1258,8 +1258,6 @@ async def subinterpreter_results(self, s: dc.DecoderSequence, variable, text, di for ic in inline_calls: si = ic.subinterpreter(self, calling_state.prompt) - if si is None: continue - subinterpreters.append(si) # prepare subinterpreter if this is the first time it is used diff --git a/src/lmql/tests/test_nested_back2back.py b/src/lmql/tests/test_nested_back2back.py new file mode 100644 index 00000000..7830561b --- /dev/null +++ b/src/lmql/tests/test_nested_back2back.py @@ -0,0 +1,65 @@ +import lmql +from lmql.tests.expr_test_utils import run_all_tests + +nl = "\n" + +class Output: + def __init__(self): + self.output_vars = {} + + def add(self, vvar, value): + if vvar in self.output_vars: + self.output_vars[vvar].append(value) + else: + self.output_vars[vvar] = [value] + + def add_all(self, vvar, values): + if vvar in self.output_vars: + self.output_vars[vvar].extend(values) + else: + self.output_vars[vvar] = list(values) + + +@lmql.query +async def make_summary_option(o, option_text, length_limit, option_type): + '''lmql + offset = len(context.prompt) + + "{:user}Make a {option_text}" + "{:assistant}[option]" where STOPS_AT(option, nl) and STOPS_AT(option, ".") + if len(option) > length_limit: + "{:user}The {option_text} is too long, make it shorter" + "{:assistant}[option]" where STOPS_AT(option, nl) and STOPS_AT(option, ".") + o.add(option_type, option.strip()) + ''' + +@lmql.query(model=lmql.model("local:llama.cpp:/lmql/llama-2-7b-chat.Q2_K.gguf", tokenizer="AyyYOO/Luna-AI-Llama2-Uncensored-FP16-sharded")) +async def test_summary_back2back(): + '''lmql + o = Output() + material_text = "This is a story about love." + + n_summaries = 2 + n_distractors = 2 + + """{:system}The user wants to make an activity for the instructions `Select the sentences that are true in the text`. + To do it they will ask you for `correct answers` or `distractors` for a particular text. + Here is the text: + + {material_text} + + If you are asked for a correct answer more than once make them different. + Only answer with the answer or distractor, nothing more. + Your responses should be shorter than 100 characters. + Avoid repetitive language. + """ + length_limit = 120 + for i in range(n_summaries): + '[correct_summaries: make_summary_option(o, "correct answer", length_limit, "correct_summaries")]' + + for i in range(n_distractors): + '[incorrect_summaries: make_summary_option(o, "distractor", length_limit, "incorrect_summaries")]' + ''' + +if __name__ == "__main__": + run_all_tests(globals()) \ No newline at end of file diff --git a/src/lmql/tests/test_stopping.py b/src/lmql/tests/test_stopping.py index 06d47df9..82b15dff 100644 --- a/src/lmql/tests/test_stopping.py +++ b/src/lmql/tests/test_stopping.py @@ -64,7 +64,7 @@ async def test_conditional_stopping(): from lmql.model("local:llama.cpp:/lmql/llama-2-7b-chat.Q2_K.gguf", tokenizer="AyyYOO/Luna-AI-Llama2-Uncensored-FP16-sharded") where - len(TOKENS(OUTPUT)) > 80 and STOPS_AT(OUTPUT, "review") + len(TOKENS(OUTPUT)) > 25 and STOPS_AT(OUTPUT, "review") ''' @lmql.query