Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add advanced resolver api to get definition within scope #1442

Merged
merged 2 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion kclvm/sema/src/core/global_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ impl GlobalState {
}
}

/// get all definition symbols within specific scope
/// get all definition symbols within specific scope and parent scope
///
/// # Parameters
///
Expand Down Expand Up @@ -196,6 +196,35 @@ impl GlobalState {
Some(all_defs)
}

/// get all definition symbols within specific scope
///
/// # Parameters
///
/// `scope`: [ScopeRef]
/// the reference of scope which was allocated by [ScopeData]
///
///
/// # Returns
///
/// result: [Option<Vec<SymbolRef>>]
/// all definition symbols in the scope
pub fn get_defs_within_scope(&self, scope: ScopeRef) -> Option<Vec<SymbolRef>> {
let scopes = &self.scopes;
let scope = scopes.get_scope(&scope)?;
let all_defs: Vec<SymbolRef> = scope
.get_defs_within_scope(
scopes,
&self.symbols,
self.packages.get_module_info(scope.get_filename()),
false,
)
.values()
.into_iter()
.cloned()
.collect();
Some(all_defs)
}

/// look up closest symbol by specific position, which means
/// the specified position is located after the starting position of the returned symbol
/// and before the starting position of the next symbol
Expand Down
50 changes: 50 additions & 0 deletions kclvm/sema/src/core/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub trait Scope {
local: bool,
) -> Option<SymbolRef>;

/// Get all defs within current scope and parent scope
fn get_all_defs(
&self,
scope_data: &ScopeData,
Expand All @@ -35,6 +36,15 @@ pub trait Scope {
recursive: bool,
) -> HashMap<String, SymbolRef>;

/// Get all defs within current scope
fn get_defs_within_scope(
&self,
scope_data: &ScopeData,
symbol_data: &Self::SymbolData,
module_info: Option<&ModuleInfo>,
recursive: bool,
) -> HashMap<String, SymbolRef>;

fn dump(&self, scope_data: &ScopeData, symbol_data: &Self::SymbolData) -> Option<String>;
}

Expand Down Expand Up @@ -328,6 +338,17 @@ impl Scope for RootSymbolScope {
fn get_range(&self) -> Option<(Position, Position)> {
None
}

fn get_defs_within_scope(
&self,
scope_data: &ScopeData,
symbol_data: &Self::SymbolData,
module_info: Option<&ModuleInfo>,
recursive: bool,
) -> HashMap<String, SymbolRef> {
// get defs within root scope equal to get all defs
self.get_all_defs(scope_data, symbol_data, module_info, recursive)
}
}

impl RootSymbolScope {
Expand Down Expand Up @@ -553,6 +574,35 @@ impl Scope for LocalSymbolScope {
fn get_range(&self) -> Option<(Position, Position)> {
Some((self.start.clone(), self.end.clone()))
}

fn get_defs_within_scope(
&self,
_scope_data: &ScopeData,
symbol_data: &Self::SymbolData,
module_info: Option<&ModuleInfo>,
_recursive: bool,
) -> HashMap<String, SymbolRef> {
let mut all_defs_map = HashMap::new();
if let Some(owner) = self.owner {
if let Some(owner) = symbol_data.get_symbol(owner) {
for def_ref in owner.get_all_attributes(symbol_data, module_info) {
if let Some(def) = symbol_data.get_symbol(def_ref) {
let name = def.get_name();
if !all_defs_map.contains_key(&name) {
all_defs_map.insert(name, def_ref);
}
}
}
}
}

for def_ref in self.defs.values() {
if let Some(def) = symbol_data.get_symbol(*def_ref) {
all_defs_map.insert(def.get_name(), *def_ref);
}
}
all_defs_map
}
}

impl LocalSymbolScope {
Expand Down
87 changes: 39 additions & 48 deletions kclvm/tools/src/LSP/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ fn completion_newline(
if let ScopeKind::Local = scope.get_kind() {
if let Some(locol_scope) = gs.get_scopes().try_get_local_scope(&scope) {
if let LocalSymbolScopeKind::SchemaConfig = locol_scope.get_kind() {
if let Some(defs) = gs.get_all_defs_in_scope(scope) {
if let Some(defs) = gs.get_defs_within_scope(scope) {
for symbol_ref in defs {
match gs.get_symbols().get_symbol(symbol_ref) {
Some(def) => {
Expand Down Expand Up @@ -1341,51 +1341,6 @@ mod tests {
}
}

#[test]
fn schema_attr_newline_completion() {
let (file, program, _, gs) =
compile_test_file("src/test_data/completion_test/newline/newline.k");

let pos = KCLPos {
filename: file.to_owned(),
line: 8,
column: Some(4),
};

let tool = toolchain::default();
let mut got = completion(Some('\n'), &program, &pos, &gs, &tool).unwrap();
match &mut got {
CompletionResponse::Array(arr) => {
arr.sort_by(|a, b| a.label.cmp(&b.label));
assert_eq!(
arr[1],
CompletionItem {
label: "c".to_string(),
kind: Some(CompletionItemKind::FIELD),
detail: Some("c: int".to_string()),
documentation: None,
..Default::default()
}
)
}
CompletionResponse::List(_) => panic!("test failed"),
}

// not complete in schema stmt
let pos = KCLPos {
filename: file.to_owned(),
line: 5,
column: Some(4),
};
let got = completion(Some('\n'), &program, &pos, &gs, &tool).unwrap();
match got {
CompletionResponse::Array(arr) => {
assert!(arr.is_empty())
}
CompletionResponse::List(_) => panic!("test failed"),
}
}

#[test]
fn schema_docstring_newline_completion() {
let (file, program, _, gs) =
Expand Down Expand Up @@ -1826,7 +1781,7 @@ mod tests {
}

#[macro_export]
macro_rules! completion_label_test_snapshot {
macro_rules! completion_label_without_builtin_func_test_snapshot {
($name:ident, $file:expr, $line:expr, $column: expr, $trigger: expr) => {
#[test]
fn $name() {
Expand All @@ -1847,6 +1802,18 @@ mod tests {
let mut labels: Vec<String> =
arr.iter().map(|item| item.label.clone()).collect();
labels.sort();
let builtin_func_lables: Vec<String> = BUILTIN_FUNCTIONS
.iter()
.map(|(name, func)| {
func_ty_complete_label(name, &func.into_func_type())
})
.collect();
let labels: Vec<String> = labels
.iter()
.filter(|label| !builtin_func_lables.contains(label))
.map(|label| label.clone())
.collect();

labels
}
CompletionResponse::List(_) => panic!("test failed"),
Expand All @@ -1856,11 +1823,35 @@ mod tests {
};
}

completion_label_test_snapshot!(
completion_label_without_builtin_func_test_snapshot!(
lambda_1,
"src/test_data/completion_test/lambda/lambda_1/lambda_1.k",
8,
5,
None
);

completion_label_without_builtin_func_test_snapshot!(
schema_attr_newline_completion_0,
"src/test_data/completion_test/newline/schema/schema_0/schema_0.k",
8,
4,
Some('\n')
);

completion_label_without_builtin_func_test_snapshot!(
schema_attr_newline_completion_0_1,
"src/test_data/completion_test/newline/schema/schema_0/schema_0.k",
5,
4,
Some('\n')
);

completion_label_without_builtin_func_test_snapshot!(
schema_attr_newline_completion_1,
"src/test_data/completion_test/newline/schema/schema_1/schema_1.k",
10,
4,
Some('\n')
);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", labels)"
expression: "format!(\"{:?}\",\n {\n let(file, program, _, gs) =\n compile_test_file(\"src/test_data/completion_test/lambda/lambda_1/lambda_1.k\")\n ; let pos = KCLPos\n { filename : file.clone(), line : 8, column : Some(5), } ; let tool =\n toolchain :: default() ; let mut got =\n completion(None, & program, & pos, & gs, & tool).unwrap() ; match &\n mut got\n {\n CompletionResponse :: Array(arr) =>\n {\n let mut labels : Vec < String > =\n arr.iter().map(| item | item.label.clone()).collect() ;\n labels.sort() ; let builtin_func_lables : Vec < String > =\n BUILTIN_FUNCTIONS.iter().map(| (name, func) |\n {\n func_ty_complete_label(name, & func.into_func_type())\n }).collect() ; println! (\"{:?}\", builtin_func_lables) ; let\n labels : Vec < String > =\n labels.iter().filter(| label |!\n builtin_func_lables.contains(label)).map(| label |\n label.clone()).collect() ; println! (\"{:?}\", labels) ; labels\n } CompletionResponse :: List(_) => panic! (\"test failed\"),\n }\n })"
---
["abs(…)", "all_true(…)", "any_true(…)", "bin(…)", "bool(…)", "case", "cases", "dict(…)", "float(…)", "func1", "hex(…)", "int(…)", "isunique(…)", "len(…)", "list(…)", "max(…)", "min(…)", "multiplyof(…)", "oct(…)", "option(…)", "ord(…)", "pow(…)", "print(…)", "range(…)", "round(…)", "sorted(…)", "str(…)", "sum(…)", "typeof(…)", "zip(…)"]
["case", "cases", "func1"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\",\n {\n let(file, program, _, gs) =\n compile_test_file(\"src/test_data/completion_test/newline/schema/schema_0/schema_0.k\")\n ; let pos = KCLPos\n { filename : file.clone(), line : 8, column : Some(4), } ; let tool =\n toolchain :: default() ; let mut got =\n completion(Some('\\n'), & program, & pos, & gs, & tool).unwrap() ;\n match & mut got\n {\n CompletionResponse :: Array(arr) =>\n {\n let mut labels : Vec < String > =\n arr.iter().map(| item | item.label.clone()).collect() ;\n labels.sort() ; let builtin_func_lables : Vec < String > =\n BUILTIN_FUNCTIONS.iter().map(| (name, func) |\n {\n func_ty_complete_label(name, & func.into_func_type())\n }).collect() ; let labels : Vec < String > =\n labels.iter().filter(| label |!\n builtin_func_lables.contains(label)).map(| label |\n label.clone()).collect() ; labels\n } CompletionResponse :: List(_) => panic! (\"test failed\"),\n }\n })"
---
["a", "c"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\",\n {\n let(file, program, _, gs) =\n compile_test_file(\"src/test_data/completion_test/newline/schema/schema_0/schema_0.k\")\n ; let pos = KCLPos\n { filename : file.clone(), line : 5, column : Some(4), } ; let tool =\n toolchain :: default() ; let mut got =\n completion(Some('\\n'), & program, & pos, & gs, & tool).unwrap() ;\n match & mut got\n {\n CompletionResponse :: Array(arr) =>\n {\n let mut labels : Vec < String > =\n arr.iter().map(| item | item.label.clone()).collect() ;\n labels.sort() ; let builtin_func_lables : Vec < String > =\n BUILTIN_FUNCTIONS.iter().map(| (name, func) |\n {\n func_ty_complete_label(name, & func.into_func_type())\n }).collect() ; let labels : Vec < String > =\n labels.iter().filter(| label |!\n builtin_func_lables.contains(label)).map(| label |\n label.clone()).collect() ; labels\n } CompletionResponse :: List(_) => panic! (\"test failed\"),\n }\n })"
---
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\",\n {\n let(file, program, _, gs) =\n compile_test_file(\"src/test_data/completion_test/newline/schema/schema_1/schema_1.k\")\n ; let pos = KCLPos\n { filename : file.clone(), line : 10, column : Some(4), } ; let tool =\n toolchain :: default() ; let mut got =\n completion(None, & program, & pos, & gs, & tool).unwrap() ; match &\n mut got\n {\n CompletionResponse :: Array(arr) =>\n {\n let mut labels : Vec < String > =\n arr.iter().map(| item | item.label.clone()).collect() ;\n labels.sort() ; let builtin_func_lables : Vec < String > =\n BUILTIN_FUNCTIONS.iter().map(| (name, func) |\n {\n func_ty_complete_label(name, & func.into_func_type())\n }).collect() ; println! (\"{:?}\", builtin_func_lables) ; let\n labels : Vec < String > =\n labels.iter().filter(| label |!\n builtin_func_lables.contains(label)).map(| label |\n label.clone()).collect() ; println! (\"{:?}\", labels) ; labels\n } CompletionResponse :: List(_) => panic! (\"test failed\"),\n }\n })"
---
["name"]
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ schema Person[b: int](Base):
c: int

p1= Person(b){

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
schema Name:
name: str

schema Person:
name: Name
age: int

p = Person{
name: Name{

}
}

Loading