Add tree-sitter injections support & cleanup

This commit is contained in:
2025-12-25 04:14:53 +00:00
parent a10dd92249
commit 659628835d
16 changed files with 302 additions and 323 deletions

162
src/ts.cc
View File

@@ -1,6 +1,7 @@
#include "../include/ts.h"
#include "../include/editor.h"
#include "../include/knot.h"
#include "../include/maps.h"
#include <algorithm>
#include <cstdint>
#include <fstream>
@@ -28,8 +29,8 @@ pcre2_code *get_re(const std::string &pattern) {
return re;
}
TSQuery *load_query(const char *query_path, Editor *editor) {
const TSLanguage *lang = editor->language;
TSQuery *load_query(const char *query_path, TSSetBase *set) {
const TSLanguage *lang = set->language;
std::ifstream file(query_path, std::ios::in | std::ios::binary);
if (!file.is_open())
return nullptr;
@@ -38,7 +39,7 @@ TSQuery *load_query(const char *query_path, Editor *editor) {
int errornumber = 0;
PCRE2_SIZE erroroffset = 0;
pcre2_code *re = pcre2_compile(
(PCRE2_SPTR) R"((@[A-Za-z0-9_.]+)|(;; \#[0-9a-fA-F]{6} \#[0-9a-fA-F]{6} [01] [01] [01] \d+))",
(PCRE2_SPTR) R"((@[A-Za-z0-9_.]+)|(;; \#[0-9a-fA-F]{6} \#[0-9a-fA-F]{6} [01] [01] [01] \d+)|(;; !(\w+)))",
PCRE2_ZERO_TERMINATED, 0, &errornumber, &erroroffset, nullptr);
if (!re)
return nullptr;
@@ -46,9 +47,8 @@ TSQuery *load_query(const char *query_path, Editor *editor) {
pcre2_match_data_create_from_pattern(re, nullptr);
std::map<std::string, int> capture_name_cache;
Highlight *c_hl = nullptr;
Language c_lang = {"unknown", nullptr, 0};
int i = 0;
int limit = 20;
editor->query_map.resize(limit);
PCRE2_SIZE offset = 0;
PCRE2_SIZE subject_length = highlight_query.size();
while (offset < subject_length) {
@@ -63,18 +63,18 @@ TSQuery *load_query(const char *query_path, Editor *editor) {
std::string capture_name = mct;
if (!capture_name_cache.count(capture_name)) {
if (c_hl) {
if (i >= limit) {
limit += 20;
editor->query_map.resize(limit);
}
editor->query_map[i] = *c_hl;
set->query_map[i] = *c_hl;
delete c_hl;
c_hl = nullptr;
}
if (c_lang.fn != nullptr) {
set->injection_map[i] = c_lang;
c_lang = {"unknown", nullptr, 0};
}
capture_name_cache[capture_name] = i;
i++;
}
} else if (mct.size() >= 2 && mct[0] == ';' && mct[1] == ';') {
} else if (mct.substr(0, 4) == ";; #") {
if (c_hl)
delete c_hl;
c_hl = new Highlight();
@@ -86,6 +86,10 @@ TSQuery *load_query(const char *query_path, Editor *editor) {
c_hl->priority = std::stoi(mct.substr(25));
c_hl->flags = (bold ? CF_BOLD : 0) | (italic ? CF_ITALIC : 0) |
(underline ? CF_UNDERLINE : 0);
} else if (mct.substr(0, 4) == ";; !") {
auto it = kLanguages.find(mct.substr(4));
if (it != kLanguages.end())
c_lang = it->second;
}
offset = ovector[1];
}
@@ -174,26 +178,32 @@ const char *read_ts(void *payload, uint32_t byte_index, TSPoint,
return leaf_from_offset(editor->root, byte_index, bytes_read);
}
static inline Highlight *safe_get(std::vector<Highlight> &vec, size_t index) {
if (index >= vec.size())
template <typename T>
static inline T *safe_get(std::map<uint16_t, T> &m, uint16_t key) {
auto it = m.find(key);
if (it == m.end())
return nullptr;
return &vec[index];
return &it->second;
}
void ts_collect_spans(Editor *editor) {
static int parse_counter = 0;
if (!editor->parser || !editor->root || !editor->query)
if (!editor->ts.parser || !editor->root || !editor->ts.query)
return;
TSInput tsinput = {
const bool injections_enabled = editor->root->char_count < (1024 * 32);
for (auto &inj : editor->ts.injections)
inj.ranges.clear();
TSInput tsinput{
.payload = editor,
.read = read_ts,
.encoding = TSInputEncodingUTF8,
.decode = nullptr,
};
TSTree *tree, *copy = nullptr;
TSTree *tree = nullptr;
TSTree *copy = nullptr;
std::unique_lock knot_mtx(editor->knot_mtx);
if (editor->tree)
copy = ts_tree_copy(editor->tree);
if (editor->ts.tree)
copy = ts_tree_copy(editor->ts.tree);
knot_mtx.unlock();
std::vector<TSInputEdit> edits;
TSInputEdit edit;
@@ -201,7 +211,7 @@ void ts_collect_spans(Editor *editor) {
while (editor->edit_queue.pop(edit)) {
edits.push_back(edit);
ts_tree_edit(copy, &edits.back());
};
}
if (copy && edits.empty() && parse_counter < 64) {
parse_counter++;
ts_tree_delete(copy);
@@ -210,41 +220,129 @@ void ts_collect_spans(Editor *editor) {
parse_counter = 0;
editor->spans.mid_parse = true;
std::shared_lock lock(editor->knot_mtx);
tree = ts_parser_parse(editor->parser, copy, tsinput);
tree = ts_parser_parse(editor->ts.parser, copy, tsinput);
lock.unlock();
if (copy)
ts_tree_delete(copy);
knot_mtx.lock();
if (editor->tree)
ts_tree_delete(editor->tree);
editor->tree = tree;
if (editor->ts.tree)
ts_tree_delete(editor->ts.tree);
editor->ts.tree = tree;
copy = ts_tree_copy(tree);
knot_mtx.unlock();
std::unordered_map<std::string, TSSet *> inj_lookup;
for (auto &inj : editor->ts.injections)
if (inj.lang != "unknown")
inj_lookup[inj.lang] = &inj;
TSQueryCursor *cursor = ts_query_cursor_new();
ts_query_cursor_exec(cursor, editor->query, ts_tree_root_node(copy));
ts_query_cursor_exec(cursor, editor->ts.query, ts_tree_root_node(copy));
std::vector<Span> new_spans;
new_spans.reserve(4096);
struct PendingRanges {
std::vector<TSRange> ranges;
TSSet *tsset = nullptr;
};
std::unordered_map<std::string, PendingRanges> pending_injections;
TSQueryMatch match;
while (ts_query_cursor_next_match(cursor, &match)) {
if (!ts_predicate(editor->query, match, editor->root))
if (!ts_predicate(editor->ts.query, match, editor->root))
continue;
for (uint32_t i = 0; i < match.capture_count; i++) {
TSQueryCapture cap = match.captures[i];
uint32_t start = ts_node_start_byte(cap.node);
uint32_t end = ts_node_end_byte(cap.node);
Highlight *hl = safe_get(editor->query_map, cap.index);
if (hl)
if (Highlight *hl = safe_get(editor->ts.query_map, cap.index))
new_spans.push_back({start, end, hl});
if (!injections_enabled)
continue;
if (Language *inj_lang = safe_get(editor->ts.injection_map, cap.index)) {
auto &pending = pending_injections[inj_lang->name];
if (!pending.tsset) {
if (auto it = inj_lookup.find(inj_lang->name);
it != inj_lookup.end()) {
pending.tsset = it->second;
} else {
TSSet fresh{};
fresh.lang = inj_lang->name;
fresh.parser = ts_parser_new();
ts_parser_set_language(fresh.parser, inj_lang->fn());
fresh.language = inj_lang->fn();
fresh.query_file =
get_exe_dir() + "/../grammar/" + inj_lang->name + ".scm";
fresh.query = load_query(fresh.query_file.c_str(), &fresh);
editor->ts.injections.push_back(std::move(fresh));
pending.tsset = &editor->ts.injections.back();
inj_lookup[inj_lang->name] = pending.tsset;
}
}
pending.ranges.push_back(TSRange{
ts_node_start_point(cap.node),
ts_node_end_point(cap.node),
start,
end,
});
}
}
}
auto overlaps = [](const Span &s, const TSRange &r) {
return !(s.end <= r.start_byte || s.start >= r.end_byte);
};
if (injections_enabled) {
for (auto &[lang_name, pending] : pending_injections) {
TSSet *tsset = pending.tsset;
if (!tsset)
continue;
tsset->ranges = std::move(pending.ranges);
if (tsset->ranges.size() > 1)
new_spans.erase(std::remove_if(new_spans.begin(), new_spans.end(),
[&](const Span &sp) {
return std::any_of(
tsset->ranges.begin(),
tsset->ranges.end(),
[&](const TSRange &r) {
return overlaps(sp, r);
});
}),
new_spans.end());
}
for (auto &inj : editor->ts.injections) {
if (!inj.parser || !inj.query || inj.ranges.size() == 0)
continue;
ts_parser_set_included_ranges(inj.parser, inj.ranges.data(),
inj.ranges.size());
std::pair<uint32_t, int64_t> span_edit;
while (editor->spans.edits.pop(span_edit))
apply_edit(new_spans, span_edit.first, span_edit.second);
knot_mtx.lock();
TSTree *inj_tree = ts_parser_parse(inj.parser, inj.tree, tsinput);
knot_mtx.unlock();
if (inj.tree)
ts_tree_delete(inj.tree);
inj.tree = inj_tree;
TSTree *inj_copy = ts_tree_copy(inj_tree);
TSQueryCursor *inj_cursor = ts_query_cursor_new();
ts_query_cursor_exec(inj_cursor, inj.query, ts_tree_root_node(inj_copy));
TSQueryMatch inj_match;
while (ts_query_cursor_next_match(inj_cursor, &inj_match)) {
if (!ts_predicate(inj.query, inj_match, editor->root))
continue;
for (uint32_t i = 0; i < inj_match.capture_count; i++) {
TSQueryCapture cap = inj_match.captures[i];
uint32_t start = ts_node_start_byte(cap.node);
uint32_t end = ts_node_end_byte(cap.node);
if (Highlight *hl = safe_get(inj.query_map, cap.index))
new_spans.push_back({start, end, hl});
}
}
ts_query_cursor_delete(inj_cursor);
ts_tree_delete(inj_copy);
}
}
ts_query_cursor_delete(cursor);
ts_tree_delete(copy);
std::sort(new_spans.begin(), new_spans.end());
std::pair<uint32_t, int64_t> span_edit;
while (editor->spans.edits.pop(span_edit))
apply_edit(new_spans, span_edit.first, span_edit.second);
std::sort(new_spans.begin(), new_spans.end());
std::unique_lock span_mtx(editor->spans.mtx);
editor->spans.mid_parse = false;
editor->spans.spans.swap(new_spans);
span_mtx.unlock();
}