Files
crib/src/ts.cc

346 lines
12 KiB
C++

#include "../include/ts.h"
#include "../include/editor.h"
#include "../include/knot.h"
#include "../include/maps.h"
#include <algorithm>
#include <cstdint>
#include <fstream>
#include <functional>
#include <string>
#include <unordered_map>
std::unordered_map<std::string, pcre2_code *> regex_cache;
void clear_regex_cache() {
for (auto &kv : regex_cache)
pcre2_code_free(kv.second);
regex_cache.clear();
}
// Returns the compiled PCRE2 program for `pattern`, compiling it on first use
// and serving it from regex_cache afterwards.
//
// Whatever pcre2_compile returns is cached and handed back unchanged, so a
// pattern that fails to compile yields nullptr (and keeps yielding nullptr on
// later calls). Callers must check the result before using it.
pcre2_code *get_re(const std::string &pattern) {
  const auto cached = regex_cache.find(pattern);
  if (cached != regex_cache.end())
    return cached->second;

  int error_code;
  PCRE2_SIZE error_offset;
  pcre2_code *compiled =
      pcre2_compile((PCRE2_SPTR)pattern.c_str(), PCRE2_ZERO_TERMINATED, 0,
                    &error_code, &error_offset, nullptr);
  // The key is known to be absent here, so emplace always inserts.
  regex_cache.emplace(pattern, compiled);
  return compiled;
}
TSQuery *load_query(const char *query_path, TSSetBase *set) {
const TSLanguage *lang = set->language;
std::ifstream file(query_path, std::ios::in | std::ios::binary);
if (!file.is_open())
return nullptr;
std::string highlight_query((std::istreambuf_iterator<char>(file)),
std::istreambuf_iterator<char>());
int errornumber = 0;
PCRE2_SIZE erroroffset = 0;
pcre2_code *re = pcre2_compile(
(PCRE2_SPTR) R"((@[A-Za-z0-9_.]+)|(;; \#[0-9a-fA-F]{6} \#[0-9a-fA-F]{6} [01] [01] [01] \d+)|(;; !(\w+)))",
PCRE2_ZERO_TERMINATED, 0, &errornumber, &erroroffset, nullptr);
if (!re)
return nullptr;
pcre2_match_data *match_data =
pcre2_match_data_create_from_pattern(re, nullptr);
std::map<std::string, int> capture_name_cache;
Highlight *c_hl = nullptr;
Language c_lang = {"unknown", nullptr, 0};
int i = 0;
PCRE2_SIZE offset = 0;
PCRE2_SIZE subject_length = highlight_query.size();
while (offset < subject_length) {
int rc = pcre2_match(re, (PCRE2_SPTR)highlight_query.c_str(),
subject_length, offset, 0, match_data, nullptr);
if (rc <= 0)
break;
PCRE2_SIZE *ovector = pcre2_get_ovector_pointer(match_data);
std::string mct =
highlight_query.substr(ovector[0], ovector[1] - ovector[0]);
if (!mct.empty() && mct[0] == '@') {
std::string capture_name = mct;
if (!capture_name_cache.count(capture_name)) {
if (c_hl) {
set->query_map[i] = *c_hl;
delete c_hl;
c_hl = nullptr;
}
if (c_lang.fn != nullptr) {
set->injection_map[i] = c_lang;
c_lang = {"unknown", nullptr, 0};
}
capture_name_cache[capture_name] = i;
i++;
}
} else if (mct.substr(0, 4) == ";; #") {
if (c_hl)
delete c_hl;
c_hl = new Highlight();
c_hl->fg = HEX(mct.substr(4, 6));
c_hl->bg = HEX(mct.substr(12, 6));
int bold = std::stoi(mct.substr(19, 1));
int italic = std::stoi(mct.substr(21, 1));
int underline = std::stoi(mct.substr(23, 1));
c_hl->priority = std::stoi(mct.substr(25));
c_hl->flags = (bold ? CF_BOLD : 0) | (italic ? CF_ITALIC : 0) |
(underline ? CF_UNDERLINE : 0);
} else if (mct.substr(0, 4) == ";; !") {
auto it = kLanguages.find(mct.substr(4));
if (it != kLanguages.end())
c_lang = it->second;
else
c_lang = {"unknown", nullptr, 0};
}
offset = ovector[1];
}
if (c_hl)
delete c_hl;
pcre2_match_data_free(match_data);
pcre2_code_free(re);
uint32_t error_offset = 0;
TSQueryError error_type = (TSQueryError)0;
TSQuery *q = ts_query_new(lang, highlight_query.c_str(),
(uint32_t)highlight_query.length(), &error_offset,
&error_type);
if (!q)
log("Failed to create TSQuery at offset %u, error type %d", error_offset,
(int)error_type);
return q;
}
// Finds the node bound to capture `capture_id` inside `match`.
//
// Returns a pointer into match.captures for the first capture whose index
// equals `capture_id`, or nullptr when the match contains no such capture.
// The pointer is only valid for as long as `match.captures` is.
static inline const TSNode *find_capture_node(const TSQueryMatch &match,
                                              uint32_t capture_id) {
  uint32_t pos = 0;
  while (pos < match.capture_count) {
    const TSQueryCapture &candidate = match.captures[pos];
    if (candidate.index == capture_id)
      return &candidate.node;
    ++pos;
  }
  return nullptr;
}
// Copies the bytes [start, end) of `source` into a std::string.
//
// The buffer returned by read() is released with free() before returning.
// Returns an empty string for an empty or inverted range (end <= start), and
// also when read() hands back a null buffer - constructing a std::string from
// a null pointer with a non-zero length would be undefined behaviour.
// NOTE(review): whether read() can actually return null is not visible from
// this file; the check is defensive.
static inline std::string node_text(uint32_t start, uint32_t end,
                                    Knot *source) {
  if (end <= start)
    return std::string();
  const uint32_t length = end - start;
  char *text = read(source, start, length);
  if (!text)
    return std::string();
  std::string result(text, length);
  free(text);
  return result;
}
// Evaluates the regex predicate attached to the pattern that produced `match`.
//
// Only predicates made of exactly four steps are evaluated - the shape of
// `(#match? @capture "regex")`: command string, capture, regex string, Done.
// `subject_fn` turns the captured node into the text the regex runs against.
//
// Returns true when the match should be kept:
//   - the pattern has no predicate, or one of a different shape;
//   - the predicate names no capture, the named capture is absent from this
//     match, or the regex failed to compile (the predicate cannot be
//     evaluated, so it does not filter anything);
//   - command "match?" and the regex matches the subject text;
//   - any other command and the regex does NOT match.
// NOTE(review): every command other than "match?" is treated as a negated
// match, not just "not-match?"; kept as-is for compatibility with existing
// query files.
bool ts_predicate(TSQuery *query, const TSQueryMatch &match,
                  std::function<std::string(const TSNode *)> subject_fn) {
  uint32_t step_count = 0;
  const TSQueryPredicateStep *steps =
      ts_query_predicates_for_pattern(query, match.pattern_index, &step_count);
  if (!steps || step_count != 4)
    return true;
  std::string command;
  std::string regex_txt;
  uint32_t subject_id = 0;
  bool has_subject = false;
  for (uint32_t i = 0; i < step_count; i++) {
    const TSQueryPredicateStep *step = &steps[i];
    if (step->type == TSQueryPredicateStepTypeDone)
      break;
    switch (step->type) {
    case TSQueryPredicateStepTypeString: {
      uint32_t length = 0;
      const char *s =
          ts_query_string_value_for_id(query, step->value_id, &length);
      // The first step is the predicate name; any later string is the regex.
      if (i == 0)
        command.assign(s, length);
      else
        regex_txt.assign(s, length);
      break;
    }
    case TSQueryPredicateStepTypeCapture: {
      subject_id = step->value_id;
      has_subject = true;
      break;
    }
    case TSQueryPredicateStepTypeDone:
      break;
    }
  }
  // Without a capture operand there is no text to test; previously this fell
  // through and silently tested capture 0.
  if (!has_subject)
    return true;
  const TSNode *node = find_capture_node(match, subject_id);
  // The capture may be missing from this particular match; subject_fn
  // dereferences the node, so it must not be called with nullptr.
  if (!node)
    return true;
  pcre2_code *re = get_re(regex_txt);
  // get_re returns nullptr for a pattern that does not compile; creating match
  // data from a null pattern is not safe.
  if (!re)
    return true;
  std::string subject = subject_fn(node);
  pcre2_match_data *md = pcre2_match_data_create_from_pattern(re, nullptr);
  int rc = pcre2_match(re, (PCRE2_SPTR)subject.c_str(), subject.size(), 0, 0,
                       md, nullptr);
  pcre2_match_data_free(md);
  // Negative return codes cover both "no match" and matching errors.
  bool ok = (rc >= 0);
  return (command == "match?" ? ok : !ok);
}
const char *read_ts(void *payload, uint32_t byte_index, TSPoint,
uint32_t *bytes_read) {
Editor *editor = (Editor *)payload;
if (byte_index >= editor->root->char_count) {
*bytes_read = 0;
return "";
}
return leaf_from_offset(editor->root, byte_index, bytes_read);
}
// Re-parses the editor's text and rebuilds its highlight spans.
//
// Steps:
//   1. Drain editor->edit_queue and apply the edits to the root tree and to
//      every existing injection tree. With no queued edits and an existing
//      tree, the call is a no-op except on every 65th consecutive such call.
//   2. Parse the root language while holding a shared lock on
//      editor->knot_mtx.
//   3. Run the highlight query over the root tree; captures with an entry in
//      query_map become spans, captures with an entry in injection_map are
//      collected as ranges per injected language.
//   4. For each injected language (only when the text is under 32 KiB and the
//      nesting depth is below kMaxInjectionDepth) drop the spans overlapping
//      its ranges, parse those ranges with the injected grammar and queue the
//      resulting tree for step 3.
//   5. Apply the edits queued in editor->spans.edits to the new spans, sort
//      them, and swap them into editor->spans.spans under editor->spans.mtx.
//
// Does nothing when the editor has no parser, no text root or no query.
void ts_collect_spans(Editor *editor) {
  // Counts consecutive calls skipped because nothing was edited. Function-local
  // static, so it is shared by every Editor passed to this function.
  static int parse_counter = 0;
  if (!editor->ts.parser || !editor->root || !editor->ts.query)
    return;
  const bool injections_enabled = editor->root->char_count < (1024 * 32);
  for (auto &inj : editor->ts.injections)
    inj.second.ranges.clear();
  TSInput tsinput{
      .payload = editor,
      .read = read_ts,
      .encoding = TSInputEncodingUTF8,
      .decode = nullptr,
  };
  std::vector<TSInputEdit> edits;
  TSInputEdit edit;
  if (!editor->edit_queue.empty()) {
    while (editor->edit_queue.pop(edit))
      edits.push_back(edit);
    if (editor->ts.tree) {
      for (auto &e : edits)
        ts_tree_edit(editor->ts.tree, &e);
    }
    for (auto &inj : editor->ts.injections) {
      if (inj.second.tree) {
        for (auto &e : edits) {
          TSInputEdit inj_edit = e;
          // NOTE(review): every injection's ranges were cleared at the top of
          // this call, so this loop currently has nothing to iterate and the
          // edit is applied unmodified.
          for (auto &r : inj.second.ranges) {
            if (e.start_byte >= r.start_byte && e.start_byte <= r.end_byte) {
              inj_edit.start_byte -= r.start_byte;
              inj_edit.old_end_byte -= r.start_byte;
              inj_edit.new_end_byte -= r.start_byte;
            }
          }
          ts_tree_edit(inj.second.tree, &inj_edit);
        }
      }
    }
  } else if (editor->ts.tree && parse_counter < 64) {
    parse_counter++;
    return;
  }
  parse_counter = 0;
  editor->spans.mid_parse = true;
  std::shared_lock lock(editor->knot_mtx);
  TSTree *tree = ts_parser_parse(editor->ts.parser, editor->ts.tree, tsinput);
  // NOTE(review): on this early return mid_parse stays true until a later
  // call completes successfully.
  if (!tree)
    return;
  if (editor->ts.tree)
    ts_tree_delete(editor->ts.tree);
  editor->ts.tree = tree;
  lock.unlock();

  std::vector<Span> new_spans;
  new_spans.reserve(4096);
  // Ranges collected for one injected language during a single query pass.
  struct PendingRanges {
    std::vector<TSRange> ranges;
    TSSet *tsset = nullptr;
  };
  // One tree still to be queried, with its injection nesting depth.
  struct WorkItem {
    TSSetBase *tsset;
    TSTree *tree;
    int depth;
  };
  const int kMaxInjectionDepth = 4;
  std::vector<WorkItem> work;
  work.push_back(
      {reinterpret_cast<TSSetBase *>(&editor->ts), editor->ts.tree, 0});
  auto overlaps = [](const Span &s, const TSRange &r) {
    return !(s.end <= r.start_byte || s.start >= r.end_byte);
  };
  // Drops every span collected so far that overlaps one of `ranges`, so the
  // injected grammar's spans replace the host grammar's inside those ranges.
  auto remove_overlapping_spans = [&](const std::vector<TSRange> &ranges) {
    if (ranges.empty())
      return;
    new_spans.erase(
        std::remove_if(new_spans.begin(), new_spans.end(),
                       [&](const Span &sp) {
                         return std::any_of(
                             ranges.begin(), ranges.end(),
                             [&](const TSRange &r) { return overlaps(sp, r); });
                       }),
        new_spans.end());
  };
  // Text of a captured node, used by ts_predicate. Loop-invariant, so it is
  // built once here rather than once per match.
  auto subject_fn = [&](const TSNode *node) -> std::string {
    uint32_t start = ts_node_start_byte(*node);
    uint32_t end = ts_node_end_byte(*node);
    return node_text(start, end, editor->root);
  };

  while (!work.empty()) {
    WorkItem item = work.back();
    work.pop_back();
    TSQuery *q = item.tsset->query;
    if (!q)
      continue;
    TSQueryCursor *cursor = ts_query_cursor_new();
    ts_query_cursor_exec(cursor, q, ts_tree_root_node(item.tsset->tree));
    std::unordered_map<std::string, PendingRanges> pending_injections;
    TSQueryMatch match;
    while (ts_query_cursor_next_match(cursor, &match)) {
      if (!ts_predicate(q, match, subject_fn))
        continue;
      for (uint32_t i = 0; i < match.capture_count; i++) {
        TSQueryCapture cap = match.captures[i];
        uint32_t start = ts_node_start_byte(cap.node);
        uint32_t end = ts_node_end_byte(cap.node);
        if (Highlight *hl = safe_get(item.tsset->query_map, cap.index))
          new_spans.push_back({start, end, hl});
        if (!injections_enabled)
          continue;
        if (Language *inj_lang =
                safe_get(item.tsset->injection_map, cap.index)) {
          auto &pending = pending_injections[inj_lang->name];
          TSSet &tsset =
              editor->ts.injections.try_emplace(inj_lang->name).first->second;
          // First time this language is injected: create its parser and
          // load its query file.
          if (!tsset.parser) {
            tsset.lang = inj_lang->name;
            tsset.parser = ts_parser_new();
            ts_parser_set_language(tsset.parser, inj_lang->fn());
            tsset.language = inj_lang->fn();
            tsset.query_file =
                get_exe_dir() + "/../grammar/" + inj_lang->name + ".scm";
            tsset.query = load_query(tsset.query_file.c_str(), &tsset);
          }
          pending.tsset = &tsset;
          pending.ranges.push_back(TSRange{
              ts_node_start_point(cap.node),
              ts_node_end_point(cap.node),
              start,
              end,
          });
        }
      }
    }
    ts_query_cursor_delete(cursor);
    if (injections_enabled && item.depth < kMaxInjectionDepth) {
      for (auto &[lang_name, pending] : pending_injections) {
        TSSet *tsset = pending.tsset;
        if (!tsset || pending.ranges.empty() || !tsset->parser || !tsset->query)
          continue;
        tsset->ranges = std::move(pending.ranges);
        remove_overlapping_spans(tsset->ranges);
        ts_parser_set_included_ranges(tsset->parser, tsset->ranges.data(),
                                      tsset->ranges.size());
        lock.lock();
        TSTree *injected_tree =
            ts_parser_parse(tsset->parser, tsset->tree, tsinput);
        if (injected_tree) {
          if (tsset->tree)
            ts_tree_delete(tsset->tree);
          tsset->tree = injected_tree;
        }
        // Always release before moving on: calling lock() again on a
        // shared_lock that still owns the mutex throws std::system_error.
        lock.unlock();
        if (!injected_tree)
          continue;
        work.push_back({reinterpret_cast<TSSetBase *>(tsset), tsset->tree,
                        item.depth + 1});
      }
    }
  }

  // Shift the freshly computed spans by the edits queued in spans.edits.
  std::pair<uint32_t, int64_t> span_edit;
  while (editor->spans.edits.pop(span_edit))
    apply_edit(new_spans, span_edit.first, span_edit.second);
  std::sort(new_spans.begin(), new_spans.end());
  std::unique_lock span_mtx(editor->spans.mtx);
  editor->spans.mid_parse = false;
  editor->spans.spans.swap(new_spans);
}