From 69b981097e493f2ce128735b12c148e181eab415 Mon Sep 17 00:00:00 2001 From: Syed Daanish Date: Thu, 25 Dec 2025 04:55:56 +0000 Subject: [PATCH] Improve syntax highlight injection logic --- grammar/ruby.scm | 20 ++++++++++++++----- include/editor.h | 4 ++-- samples/ruby.rb | 4 ++-- src/editor.cc | 10 ++++------ src/ts.cc | 52 ++++++++++++++++++------------------------------ 5 files changed, 42 insertions(+), 48 deletions(-) diff --git a/grammar/ruby.scm b/grammar/ruby.scm index 6520ae2..d38aa34 100644 --- a/grammar/ruby.scm +++ b/grammar/ruby.scm @@ -1,6 +1,16 @@ -; This is an injection test - it should hight all heredoc content as bash code -;; !bash - this part should be ignored (anything after the first wordbreak after the `!`) -(heredoc_content) @ruby_injection +(heredoc_body +;; !bash + (heredoc_content) @bash_injection + ((heredoc_end) @lang + (#match? @lang "BASH")) +) + +(heredoc_body +;; !ruby + (heredoc_content) @ruby_injection + ((heredoc_end) @lang + (#match? @lang "RUBY")) +) ;; #ffffff #000000 0 0 0 1 [ @@ -147,10 +157,10 @@ ((call !receiver method: (identifier) @function.builtin) - (#match? @function.builtin "^(include|extend|prepend|refine|using)")) + (#match? @function.builtin "^(include|extend|prepend|refine|using)$")) ((identifier) @keyword.exception - (#match? @keyword.exception "^(raise|fail|catch|throw)" )) + (#match? @keyword.exception "^(raise|fail|catch|throw)$" )) ;; #ffffff #000000 0 0 0 1 [ diff --git a/include/editor.h b/include/editor.h index b020298..eac69af 100644 --- a/include/editor.h +++ b/include/editor.h @@ -121,7 +121,6 @@ struct VAI { struct TSSetBase { std::string lang; - TSTree *tree; TSParser *parser; std::string query_file; TSQuery *query; @@ -135,7 +134,8 @@ struct TSSet : TSSetBase { }; struct TSSetMain : TSSetBase { - std::vector injections; + TSTree *tree; + std::unordered_map injections; }; struct Editor { diff --git a/samples/ruby.rb b/samples/ruby.rb index cb9945d..2bc5286 100644 --- a/samples/ruby.rb +++ b/samples/ruby.rb @@ -43,7 +43,7 @@ end puts "Emoji count: #{emojis.length}" # Multi-line string with unicode -multi = <<~EOF +multi = <<~BASH # Function recursion demo factorial() { local n="$1" @@ -57,7 +57,7 @@ multi = <<~EOF } log INFO "factorial(5) = $(factorial 5)" -EOF +BASH puts multi diff --git a/src/editor.cc b/src/editor.cc index 0ffa372..cc3aa98 100644 --- a/src/editor.cc +++ b/src/editor.cc @@ -43,12 +43,10 @@ void free_tsset(TSSetMain *set) { if (set->query) ts_query_delete(set->query); for (auto &inj : set->injections) { - if (inj.parser) - ts_parser_delete(inj.parser); - if (inj.tree) - ts_tree_delete(inj.tree); - if (inj.query) - ts_query_delete(inj.query); + if (inj.second.parser) + ts_parser_delete(inj.second.parser); + if (inj.second.query) + ts_query_delete(inj.second.query); } } diff --git a/src/ts.cc b/src/ts.cc index 2f7bd48..4e755e7 100644 --- a/src/ts.cc +++ b/src/ts.cc @@ -192,7 +192,7 @@ void ts_collect_spans(Editor *editor) { return; const bool injections_enabled = editor->root->char_count < (1024 * 32); for (auto &inj : editor->ts.injections) - inj.ranges.clear(); + inj.second.ranges.clear(); TSInput tsinput{ .payload = editor, .read = read_ts, @@ -228,10 +228,6 @@ void ts_collect_spans(Editor *editor) { ts_tree_delete(editor->ts.tree); editor->ts.tree = tree; copy = ts_tree_copy(tree); - std::unordered_map inj_lookup; - for (auto &inj : editor->ts.injections) - if (inj.lang != "unknown") - inj_lookup[inj.lang] = &inj; TSQueryCursor *cursor = ts_query_cursor_new(); ts_query_cursor_exec(cursor, editor->ts.query, ts_tree_root_node(copy)); std::vector new_spans; @@ -255,24 +251,18 @@ void ts_collect_spans(Editor *editor) { continue; if (Language *inj_lang = safe_get(editor->ts.injection_map, cap.index)) { auto &pending = pending_injections[inj_lang->name]; - if (!pending.tsset) { - if (auto it = inj_lookup.find(inj_lang->name); - it != inj_lookup.end()) { - pending.tsset = it->second; - } else { - TSSet fresh{}; - fresh.lang = inj_lang->name; - fresh.parser = ts_parser_new(); - ts_parser_set_language(fresh.parser, inj_lang->fn()); - fresh.language = inj_lang->fn(); - fresh.query_file = - get_exe_dir() + "/../grammar/" + inj_lang->name + ".scm"; - fresh.query = load_query(fresh.query_file.c_str(), &fresh); - editor->ts.injections.push_back(std::move(fresh)); - pending.tsset = &editor->ts.injections.back(); - inj_lookup[inj_lang->name] = pending.tsset; - } + TSSet &tsset = + editor->ts.injections.try_emplace(inj_lang->name).first->second; + if (!tsset.parser) { + tsset.lang = inj_lang->name; + tsset.parser = ts_parser_new(); + ts_parser_set_language(tsset.parser, inj_lang->fn()); + tsset.language = inj_lang->fn(); + tsset.query_file = + get_exe_dir() + "/../grammar/" + inj_lang->name + ".scm"; + tsset.query = load_query(tsset.query_file.c_str(), &tsset); } + pending.tsset = &tsset; pending.ranges.push_back(TSRange{ ts_node_start_point(cap.node), ts_node_end_point(cap.node), @@ -291,7 +281,7 @@ void ts_collect_spans(Editor *editor) { if (!tsset) continue; tsset->ranges = std::move(pending.ranges); - if (tsset->ranges.size() > 1) + if (tsset->ranges.size() > 0) new_spans.erase(std::remove_if(new_spans.begin(), new_spans.end(), [&](const Span &sp) { return std::any_of( @@ -303,23 +293,19 @@ void ts_collect_spans(Editor *editor) { }), new_spans.end()); } - for (auto &inj : editor->ts.injections) { + for (auto &kv : editor->ts.injections) { + auto &inj = kv.second; if (!inj.parser || !inj.query || inj.ranges.size() == 0) continue; ts_parser_set_included_ranges(inj.parser, inj.ranges.data(), inj.ranges.size()); + knot_mtx.lock(); std::pair span_edit; while (editor->spans.edits.pop(span_edit)) apply_edit(new_spans, span_edit.first, span_edit.second); - knot_mtx.lock(); - TSTree *inj_tree = ts_parser_parse(inj.parser, inj.tree, tsinput); - knot_mtx.unlock(); - if (inj.tree) - ts_tree_delete(inj.tree); - inj.tree = inj_tree; - TSTree *inj_copy = ts_tree_copy(inj_tree); + TSTree *inj_tree = ts_parser_parse(inj.parser, nullptr, tsinput); TSQueryCursor *inj_cursor = ts_query_cursor_new(); - ts_query_cursor_exec(inj_cursor, inj.query, ts_tree_root_node(inj_copy)); + ts_query_cursor_exec(inj_cursor, inj.query, ts_tree_root_node(inj_tree)); TSQueryMatch inj_match; while (ts_query_cursor_next_match(inj_cursor, &inj_match)) { if (!ts_predicate(inj.query, inj_match, editor->root)) @@ -333,7 +319,7 @@ void ts_collect_spans(Editor *editor) { } } ts_query_cursor_delete(inj_cursor); - ts_tree_delete(inj_copy); + ts_tree_delete(inj_tree); } } ts_query_cursor_delete(cursor);