From 533608f56b912aba98250a3c1501ee687d7cf5eb Mon Sep 17 00:00:00 2001 From: Liam Dyer Date: Tue, 10 Dec 2024 16:37:02 -0500 Subject: [PATCH] feat: sources v2 (#465) Large rewrite of how sources are handled, adding support for async providers/timeouts, tree based fallbacks, dynamically adding sources and some other goodies Closes #386 Closes #219 Closes #328 Closes #331 Closes #312 Closes #454 Closes #444 Closes #372 Closes #475 --- README.md | 67 ++++--- lua/blink/cmp/completion/list.lua | 58 ++++-- lua/blink/cmp/completion/trigger.lua | 26 ++- lua/blink/cmp/config/completion/trigger.lua | 8 +- lua/blink/cmp/config/fuzzy.lua | 6 +- lua/blink/cmp/config/init.lua | 1 + lua/blink/cmp/config/sources.lua | 73 ++++---- lua/blink/cmp/fuzzy/fuzzy.rs | 60 ++----- lua/blink/cmp/fuzzy/init.lua | 55 +++--- lua/blink/cmp/fuzzy/lib.rs | 38 +++- lua/blink/cmp/fuzzy/sort.lua | 44 +++++ lua/blink/cmp/init.lua | 16 +- lua/blink/cmp/lib/async.lua | 32 +++- lua/blink/cmp/lib/event_emitter.lua | 10 +- lua/blink/cmp/lib/utils.lua | 66 +++++++ lua/blink/cmp/sources/lib/context.lua | 155 ---------------- lua/blink/cmp/sources/lib/init.lua | 126 ++++++------- lua/blink/cmp/sources/lib/provider/config.lua | 16 +- lua/blink/cmp/sources/lib/provider/init.lua | 91 +++++----- lua/blink/cmp/sources/lib/provider/list.lua | 127 +++++++++++++ lua/blink/cmp/sources/lib/queue.lua | 68 +++++++ lua/blink/cmp/sources/lib/tree.lua | 168 ++++++++++++++++++ lua/blink/cmp/sources/lib/types.lua | 3 +- lua/blink/cmp/sources/lib/utils.lua | 48 ----- lua/blink/cmp/sources/lsp.lua | 148 +++++---------- lua/blink/cmp/sources/snippets/utils.lua | 8 +- 26 files changed, 922 insertions(+), 596 deletions(-) create mode 100644 lua/blink/cmp/fuzzy/sort.lua delete mode 100644 lua/blink/cmp/sources/lib/context.lua create mode 100644 lua/blink/cmp/sources/lib/provider/list.lua create mode 100644 lua/blink/cmp/sources/lib/queue.lua create mode 100644 lua/blink/cmp/sources/lib/tree.lua diff --git a/README.md b/README.md index 59fdd72c..e507d9ed 100644 --- a/README.md +++ b/README.md @@ -72,9 +72,7 @@ -- default list of enabled providers defined so that you can extend it -- elsewhere in your config, without redefining it, via `opts_extend` sources = { - completion = { - enabled_providers = { 'lsp', 'path', 'snippets', 'buffer' }, - }, + default = { 'lsp', 'path', 'snippets', 'buffer' }, }, -- experimental auto-brackets support @@ -83,9 +81,9 @@ -- experimental signature help support -- signature = { enabled = true } }, - -- allows extending the enabled_providers array elsewhere in your config + -- allows extending the providers array elsewhere in your config -- without having to redefine it - opts_extend = { "sources.completion.enabled_providers" } + opts_extend = { "sources.default" } }, ``` @@ -291,6 +289,14 @@ MiniDeps.add({ -- however, some LSPs (i.e. tsserver) return characters that would essentially -- always show the window. We block these by default. show_on_blocked_trigger_characters = { ' ', '\n', '\t' }, + -- or a function like + -- show_on_blocked_trigger_characters = function() + -- local blocked = { ' ', '\n', '\t' } + -- if vim.bo.filetype == 'markdown' then + -- vim.list_extend(blocked, { '.', '/', '(', '[' }) + -- end + -- return blocked + -- end -- When both this and show_on_trigger_character are true, will show the completion window -- when the cursor comes after a trigger character after accepting an item show_on_accept_on_trigger_character = true, @@ -301,6 +307,7 @@ MiniDeps.add({ -- the completion window when the cursor comes after a trigger character when -- entering insert mode/accepting an item show_on_x_blocked_trigger_characters = { "'", '"', '(' }, + -- or a function, similar to show_on_blocked_trigger_character }, list = { @@ -516,8 +523,9 @@ MiniDeps.add({ -- proximity bonus boosts the score of items matching nearby words use_proximity = true, max_items = 200, - -- controls which sorts to use and in which order, these three are currently the only allowed options - sorts = { 'label', 'kind', 'score' }, + -- controls which sorts to use and in which order, falling back to the next sort if the first one returns nil + -- you may pass a function instead of a string to customize the sorting + sorts = { 'score', 'kind', 'label' }, prebuilt_binaries = { -- Whether or not to automatically download a prebuilt binary from github. If this is set to `false` @@ -538,20 +546,23 @@ MiniDeps.add({ }, sources = { - completion = { - -- Static list of providers to enable, or a function to dynamically enable/disable providers based on the context - enabled_providers = { 'lsp', 'path', 'snippets', 'buffer' }, - -- Example dynamically picking providers based on the filetype and treesitter node: - -- enabled_providers = function(ctx) - -- local node = vim.treesitter.get_node() - -- if vim.bo.filetype == 'lua' then - -- return { 'lsp', 'path' } - -- elseif node and vim.tbl_contains({ 'comment', 'line_comment', 'block_comment' }, node:type()) then - -- return { 'buffer' } - -- else - -- return { 'lsp', 'path', 'snippets', 'buffer' } - -- end - -- end, + -- Static list of providers to enable, or a function to dynamically enable/disable providers based on the context + default = { 'lsp', 'path', 'snippets', 'buffer' }, + -- Example dynamically picking providers based on the filetype and treesitter node: + -- providers = function(ctx) + -- local node = vim.treesitter.get_node() + -- if vim.bo.filetype == 'lua' then + -- return { 'lsp', 'path' } + -- elseif node and vim.tbl_contains({ 'comment', 'line_comment', 'block_comment' }, node:type()) then + -- return { 'buffer' } + -- else + -- return { 'lsp', 'path', 'snippets', 'buffer' } + -- end + -- end + + -- You may also define providers per filetype + per_filetype = { + -- lua = { 'lsp', 'path' }, }, -- Please see https://github.com/Saghen/blink.compat for using `nvim-cmp` sources @@ -560,16 +571,19 @@ MiniDeps.add({ name = 'LSP', module = 'blink.cmp.sources.lsp', - --- *All* of the providers have the following options available + --- *All* providers have the following options available --- NOTE: All of these options may be functions to get dynamic behavior --- See the type definitions for more information. - --- Check the enabled_providers config for an example enabled = true, -- Whether or not to enable the provider + async = false, -- Whether we should wait for the provider to return before showing the completions + timeout_ms = 2000, -- How long to wait for the provider to return before showing completions and treating it as asynchronous transform_items = nil, -- Function to transform the items before they're returned should_show_items = true, -- Whether or not to show the items max_items = nil, -- Maximum number of items to display in the menu min_keyword_length = 0, -- Minimum number of characters in the keyword to trigger the provider - fallback_for = {}, -- If any of these providers return 0 items, it will fallback to this provider + -- If this provider returns 0 items, it will fallback to these providers. + -- If multiple providers falback to the same provider, all of the providers must return 0 items for it to fallback + fallbacks = { 'buffer' }, score_offset = 0, -- Boost/penalize the score of the items override = nil, -- Override the source's functions }, @@ -607,7 +621,6 @@ MiniDeps.add({ buffer = { name = 'Buffer', module = 'blink.cmp.sources.buffer', - fallback_for = { 'lsp' }, opts = { -- default to all visible buffers get_bufnrs = function() @@ -730,9 +743,7 @@ MiniDeps.add({ jump = function(direction) require('luasnip').jump(direction) end, }, sources = { - completion = { - enabled_providers = { 'lsp', 'path', 'luasnip', 'buffer' }, - }, + default = { 'lsp', 'path', 'luasnip', 'buffer' }, }, } } diff --git a/lua/blink/cmp/completion/list.lua b/lua/blink/cmp/completion/list.lua index 6e32214c..50eefd00 100644 --- a/lua/blink/cmp/completion/list.lua +++ b/lua/blink/cmp/completion/list.lua @@ -10,14 +10,15 @@ --- @field select_emitter blink.cmp.EventEmitter --- @field accept_emitter blink.cmp.EventEmitter --- ---- @field show fun(context: blink.cmp.Context, items?: blink.cmp.CompletionItem[]) ---- @field fuzzy fun(context: blink.cmp.Context, items: blink.cmp.CompletionItem[]): blink.cmp.CompletionItem[] +--- @field show fun(context: blink.cmp.Context, items: table) +--- @field fuzzy fun(context: blink.cmp.Context, items: table): blink.cmp.CompletionItem[] --- @field hide fun() --- --- @field get_selected_item fun(): blink.cmp.CompletionItem? ---- @field select fun(idx?: number, opts?: { undo_preview?: boolean }) +--- @field select fun(idx?: number, opts?: { undo_preview?: boolean, is_explicit_selection?: boolean }) --- @field select_next fun() --- @field select_prev fun() +--- @field get_item_idx_in_list fun(item?: blink.cmp.CompletionItem): number --- --- @field undo_preview fun() --- @field apply_preview fun(item: blink.cmp.CompletionItem) @@ -54,18 +55,31 @@ local list = { context = nil, items = {}, selected_item_idx = nil, + is_explicitly_selected = false, preview_undo_text_edit = nil, } ---------- State ---------- -function list.show(context, items) +function list.show(context, items_by_source) -- reset state for new context local is_new_context = not list.context or list.context.id ~= context.id - if is_new_context then list.preview_undo_text_edit = nil end + if is_new_context then + list.preview_undo_text_edit = nil + list.is_explicitly_selected = false + end + + -- if the keyword changed, the list is no longer explicitly selected + local bounds_equal = list.context ~= nil + and list.context.bounds.start_col == context.bounds.start_col + and list.context.bounds.length == context.bounds.length + if not bounds_equal then list.is_explicitly_selected = false end + + local previous_selected_item = list.get_selected_item() + -- update the context/list and emit list.context = context - list.items = list.fuzzy(context, items or list.items) + list.items = list.fuzzy(context, items_by_source) if #list.items == 0 then list.hide_emitter:emit({ context = context }) @@ -73,16 +87,29 @@ function list.show(context, items) list.show_emitter:emit({ items = list.items, context = context }) end - -- todo: some logic to maintain the selection if the user moved the cursor? - list.select(list.config.selection == 'preselect' and 1 or nil, { undo_preview = false }) + -- maintain the selection if the user selected an item + local previous_item_idx = list.get_item_idx_in_list(previous_selected_item) + if list.is_explicitly_selected and previous_item_idx ~= nil and previous_item_idx <= 10 then + list.select(previous_item_idx, { undo_preview = false }) + + -- otherwise, use the default selection + else + list.select( + list.config.selection == 'preselect' and 1 or nil, + { undo_preview = false, is_explicit_selection = false } + ) + end end -function list.fuzzy(context, items) +function list.fuzzy(context, items_by_source) local fuzzy = require('blink.cmp.fuzzy') - local sources = require('blink.cmp.sources.lib') + local filtered_items = fuzzy.fuzzy(fuzzy.get_query(), items_by_source) + + -- apply the per source max_items + filtered_items = require('blink.cmp.sources.lib').apply_max_items_for_completions(context, filtered_items) - local filtered_items = fuzzy.fuzzy(fuzzy.get_query(), items) - return sources.apply_max_items_for_completions(context, filtered_items) + -- apply the global max_items + return require('blink.cmp.lib.utils').slice(filtered_items, 1, list.config.max_items) end function list.hide() list.hide_emitter:emit({ context = list.context }) end @@ -101,6 +128,8 @@ function list.select(idx, opts) if list.config.selection == 'auto_insert' and item then list.apply_preview(item) end end) + --- @diagnostic disable-next-line: assign-type-mismatch + list.is_explicitly_selected = opts.is_explicit_selection == nil and true or opts.is_explicit_selection list.selected_item_idx = idx list.select_emitter:emit({ idx = idx, item = item, items = list.items, context = list.context }) end @@ -149,6 +178,11 @@ function list.select_prev() list.select(list.selected_item_idx - 1) end +function list.get_item_idx_in_list(item) + if item == nil then return end + return require('blink.cmp.lib.utils').find_idx(list.items, function(i) return i.label == item.label end) +end + ---------- Preview ---------- function list.undo_preview() diff --git a/lua/blink/cmp/completion/trigger.lua b/lua/blink/cmp/completion/trigger.lua index 93b8fdc2..034e9b4c 100644 --- a/lua/blink/cmp/completion/trigger.lua +++ b/lua/blink/cmp/completion/trigger.lua @@ -16,6 +16,7 @@ --- @field line string --- @field bounds blink.cmp.ContextBounds --- @field trigger { kind: number, character: string | nil } +--- @field providers string[] --- @class blink.cmp.CompletionTrigger --- @field buffer_events blink.cmp.BufferEvents @@ -28,7 +29,7 @@ --- @field is_trigger_character fun(char: string, is_retrigger?: boolean): boolean --- @field suppress_events_for_callback fun(cb: fun()) --- @field show_if_on_trigger_character fun(opts?: { is_accept?: boolean }) ---- @field show fun(opts?: { trigger_character?: string, force?: boolean, send_upstream?: boolean }) +--- @field show fun(opts?: { trigger_character?: string, force?: boolean, send_upstream?: boolean, providers?: string[] }) --- @field hide fun() --- @field within_query_bounds fun(cursor: number[]): boolean --- @field get_context_bounds fun(regex: string): blink.cmp.ContextBounds @@ -119,8 +120,17 @@ function trigger.is_trigger_character(char, is_show_on_x) local sources = require('blink.cmp.sources.lib') local is_trigger = vim.tbl_contains(sources.get_trigger_characters(), char) - local is_blocked = vim.tbl_contains(config.show_on_blocked_trigger_characters, char) - or (is_show_on_x and vim.tbl_contains(config.show_on_x_blocked_trigger_characters, char)) + local show_on_blocked_trigger_characters = type(config.show_on_blocked_trigger_characters) == 'function' + and config.show_on_blocked_trigger_characters() + or config.show_on_blocked_trigger_characters + --- @cast show_on_blocked_trigger_characters string[] + local show_on_x_blocked_trigger_characters = type(config.show_on_x_blocked_trigger_characters) == 'function' + and config.show_on_x_blocked_trigger_characters() + or config.show_on_x_blocked_trigger_characters + --- @cast show_on_x_blocked_trigger_characters string[] + + local is_blocked = vim.tbl_contains(show_on_blocked_trigger_characters, char) + or (is_show_on_x and vim.tbl_contains(show_on_x_blocked_trigger_characters, char)) return is_trigger and not is_blocked end @@ -165,7 +175,14 @@ function trigger.show(opts) end -- update context - if trigger.context == nil then trigger.current_context_id = trigger.current_context_id + 1 end + if trigger.context == nil or opts.providers ~= nil then + trigger.current_context_id = trigger.current_context_id + 1 + end + + local providers = opts.providers + or (trigger.context and trigger.context.providers) + or require('blink.cmp.sources.lib').get_enabled_provider_ids() + trigger.context = { id = trigger.current_context_id, bufnr = vim.api.nvim_get_current_buf(), @@ -177,6 +194,7 @@ function trigger.show(opts) or vim.lsp.protocol.CompletionTriggerKind.Invoked, character = opts.trigger_character, }, + providers = providers, } if opts.send_upstream ~= false then trigger.show_emitter:emit({ context = trigger.context }) end diff --git a/lua/blink/cmp/config/completion/trigger.lua b/lua/blink/cmp/config/completion/trigger.lua index 9c982ad9..5ea984e6 100644 --- a/lua/blink/cmp/config/completion/trigger.lua +++ b/lua/blink/cmp/config/completion/trigger.lua @@ -2,10 +2,10 @@ --- @field show_in_snippet boolean When false, will not show the completion window when in a snippet --- @field show_on_keyword boolean When true, will show the completion window after typing a character that matches the `keyword.regex` --- @field show_on_trigger_character boolean When true, will show the completion window after typing a trigger character ---- @field show_on_blocked_trigger_characters string[] LSPs can indicate when to show the completion window via trigger characters. However, some LSPs (i.e. tsserver) return characters that would essentially always show the window. We block these by default. +--- @field show_on_blocked_trigger_characters string[] | (fun(): string[]) LSPs can indicate when to show the completion window via trigger characters. However, some LSPs (i.e. tsserver) return characters that would essentially always show the window. We block these by default. --- @field show_on_accept_on_trigger_character boolean When both this and show_on_trigger_character are true, will show the completion window when the cursor comes after a trigger character after accepting an item --- @field show_on_insert_on_trigger_character boolean When both this and show_on_trigger_character are true, will show the completion window when the cursor comes after a trigger character when entering insert mode ---- @field show_on_x_blocked_trigger_characters string[] List of trigger characters (on top of `show_on_blocked_trigger_characters`) that won't trigger the completion window when the cursor comes after a trigger character when entering insert mode/accepting an item +--- @field show_on_x_blocked_trigger_characters string[] | (fun(): string[]) List of trigger characters (on top of `show_on_blocked_trigger_characters`) that won't trigger the completion window when the cursor comes after a trigger character when entering insert mode/accepting an item local validate = require('blink.cmp.config.utils').validate local trigger = { @@ -26,10 +26,10 @@ function trigger.validate(config) show_in_snippet = { config.show_in_snippet, 'boolean' }, show_on_keyword = { config.show_on_keyword, 'boolean' }, show_on_trigger_character = { config.show_on_trigger_character, 'boolean' }, - show_on_blocked_trigger_characters = { config.show_on_blocked_trigger_characters, 'table' }, + show_on_blocked_trigger_characters = { config.show_on_blocked_trigger_characters, { 'function', 'table' } }, show_on_accept_on_trigger_character = { config.show_on_accept_on_trigger_character, 'boolean' }, show_on_insert_on_trigger_character = { config.show_on_insert_on_trigger_character, 'boolean' }, - show_on_x_blocked_trigger_characters = { config.show_on_x_blocked_trigger_characters, 'table' }, + show_on_x_blocked_trigger_characters = { config.show_on_x_blocked_trigger_characters, { 'function', 'table' } }, }) end diff --git a/lua/blink/cmp/config/fuzzy.lua b/lua/blink/cmp/config/fuzzy.lua index b91727d6..bbe53540 100644 --- a/lua/blink/cmp/config/fuzzy.lua +++ b/lua/blink/cmp/config/fuzzy.lua @@ -2,7 +2,7 @@ --- @field use_typo_resistance boolean When enabled, allows for a number of typos relative to the length of the query. Disabling this matches the behavior of fzf --- @field use_frecency boolean Tracks the most recently/frequently used items and boosts the score of the item --- @field use_proximity boolean Boosts the score of items matching nearby words ---- @field sorts ("label" | "kind" | "score")[] Controls which sorts to use and in which order, these three are currently the only allowed options +--- @field sorts ("label" | "kind" | "score" | blink.cmp.SortFunction)[] Controls which sorts to use and in which order, these three are currently the only allowed options --- @field prebuilt_binaries blink.cmp.PrebuiltBinariesConfig --- @class (exact) blink.cmp.PrebuiltBinariesConfig @@ -10,6 +10,8 @@ --- @field force_version? string When downloading a prebuilt binary, force the downloader to resolve this version. If this is unset then the downloader will attempt to infer the version from the checked out git tag (if any). WARN: Beware that `main` may be incompatible with the version you select --- @field force_system_triple? string When downloading a prebuilt binary, force the downloader to use this system triple. If this is unset then the downloader will attempt to infer the system triple from `jit.os` and `jit.arch`. Check the latest release for all available system triples. WARN: Beware that `main` may be incompatible with the version you select +--- @alias blink.cmp.SortFunction fun(a: blink.cmp.CompletionItem, b: blink.cmp.CompletionItem): boolean | nil + local validate = require('blink.cmp.config.utils').validate local fuzzy = { --- @type blink.cmp.FuzzyConfig @@ -17,7 +19,7 @@ local fuzzy = { use_typo_resistance = true, use_frecency = true, use_proximity = true, - sorts = { 'label', 'kind', 'score' }, + sorts = { 'score', 'kind', 'label' }, prebuilt_binaries = { download = true, force_version = nil, diff --git a/lua/blink/cmp/config/init.lua b/lua/blink/cmp/config/init.lua index f446a9f9..c9e40379 100644 --- a/lua/blink/cmp/config/init.lua +++ b/lua/blink/cmp/config/init.lua @@ -2,6 +2,7 @@ --- @field enabled fun(): boolean --- @field keymap blink.cmp.KeymapConfig --- @field completion blink.cmp.CompletionConfig +--- @field fuzzy blink.cmp.FuzzyConfig --- @field sources blink.cmp.SourceConfig --- @field signature blink.cmp.SignatureConfig --- @field snippets blink.cmp.SnippetsConfig diff --git a/lua/blink/cmp/config/sources.lua b/lua/blink/cmp/config/sources.lua index cbed0a5d..1acea657 100644 --- a/lua/blink/cmp/config/sources.lua +++ b/lua/blink/cmp/config/sources.lua @@ -1,8 +1,4 @@ --- @class blink.cmp.SourceConfig ---- @field completion blink.cmp.SourceModeConfig ---- @field providers table - ---- @class blink.cmp.SourceModeConfig --- Static list of providers to enable, or a function to dynamically enable/disable providers based on the context --- --- Example dynamically picking providers based on the filetype and treesitter node: @@ -18,18 +14,21 @@ --- end --- end --- ``` ---- @field enabled_providers string[] | fun(ctx?: blink.cmp.Context): string[] +--- @field default string[] | fun(): string[] +--- @field per_filetype table +--- @field providers table --- @class blink.cmp.SourceProviderConfig --- @field name? string --- @field module? string --- @field enabled? boolean | fun(ctx?: blink.cmp.Context): boolean Whether or not to enable the provider --- @field opts? table +--- @field async? boolean Whether blink should wait for the source to return before showing the completions --- @field transform_items? fun(ctx: blink.cmp.Context, items: blink.cmp.CompletionItem[]): blink.cmp.CompletionItem[] Function to transform the items before they're returned ---- @field should_show_items? boolean | number | fun(ctx: blink.cmp.Context, items: blink.cmp.CompletionItem[]): boolean Whether or not to show the items ---- @field max_items? number | fun(ctx: blink.cmp.Context, enabled_sources: string[], items: blink.cmp.CompletionItem[]): number Maximum number of items to display in the menu ---- @field min_keyword_length? number | fun(ctx: blink.cmp.Context, enabled_sources: string[]): number Minimum number of characters in the keyword to trigger the provider ---- @field fallback_for? string[] | fun(ctx: blink.cmp.Context, enabled_sources: string[]): string[] If any of these providers return 0 items, it will fallback to this provider +--- @field should_show_items? boolean | fun(ctx: blink.cmp.Context, items: blink.cmp.CompletionItem[]): boolean Whether or not to show the items +--- @field max_items? number | fun(ctx: blink.cmp.Context, items: blink.cmp.CompletionItem[]): number Maximum number of items to display in the menu +--- @field min_keyword_length? number | fun(ctx: blink.cmp.Context): number Minimum number of characters in the keyword to trigger the provider +--- @field fallbacks? string[] | fun(ctx: blink.cmp.Context, enabled_sources: string[]): string[] If this provider returns 0 items, it will fallback to these providers --- @field score_offset? number | fun(ctx: blink.cmp.Context, enabled_sources: string[]): number Boost/penalize the score of the items --- @field deduplicate? blink.cmp.DeduplicateConfig TODO: implement --- @field override? blink.cmp.SourceOverride Override the source's functions @@ -38,13 +37,13 @@ local validate = require('blink.cmp.config.utils').validate local sources = { --- @type blink.cmp.SourceConfig default = { - completion = { - enabled_providers = { 'lsp', 'path', 'snippets', 'buffer' }, - }, + default = { 'lsp', 'path', 'snippets', 'buffer' }, + per_filetype = {}, providers = { lsp = { name = 'LSP', module = 'blink.cmp.sources.lsp', + fallbacks = { 'buffer' }, }, path = { name = 'Path', @@ -64,7 +63,6 @@ local sources = { buffer = { name = 'Buffer', module = 'blink.cmp.sources.buffer', - fallback_for = { 'lsp' }, }, }, }, @@ -72,28 +70,39 @@ local sources = { function sources.validate(config) validate('sources', { - completion = { config.completion, 'table' }, + default = { config.default, { 'function', 'table' } }, + per_filetype = { config.per_filetype, 'table' }, providers = { config.providers, 'table' }, }) - validate('sources.completion', { - enabled_providers = { config.completion.enabled_providers, { 'table', 'function' } }, - }) - for key, provider in pairs(config.providers) do - validate('sources.providers.' .. key, { - name = { provider.name, 'string' }, - module = { provider.module, 'string' }, - enabled = { provider.enabled, { 'boolean', 'function' }, true }, - opts = { provider.opts, 'table', true }, - transform_items = { provider.transform_items, 'function', true }, - should_show_items = { provider.should_show_items, { 'boolean', 'function' }, true }, - max_items = { provider.max_items, { 'number', 'function' }, true }, - min_keyword_length = { provider.min_keyword_length, { 'number', 'function' }, true }, - fallback_for = { provider.fallback_for, { 'table', 'function' }, true }, - score_offset = { provider.score_offset, { 'number', 'function' }, true }, - deduplicate = { provider.deduplicate, 'table', true }, - override = { provider.override, 'table', true }, - }) + assert( + config.completion == nil, + '`sources.completion.enabled_providers` has been replaced with `sources.default`. !!Note!! Be sure to update `opts_extend` as well if you have it set' + ) + for id, provider in pairs(config.providers) do + sources.validate_provider(id, provider) end end +function sources.validate_provider(id, provider) + assert( + provider.fallback_for == nil, + '`fallback_for` has been replaced with `fallbacks` which work in the opposite direction. For example, fallback_for = { "lsp" } on "buffer" would now be "fallbacks" = { "buffer" } on "lsp"' + ) + + validate('sources.providers.' .. id, { + name = { provider.name, 'string' }, + module = { provider.module, 'string' }, + enabled = { provider.enabled, { 'boolean', 'function' }, true }, + opts = { provider.opts, 'table', true }, + transform_items = { provider.transform_items, 'function', true }, + should_show_items = { provider.should_show_items, { 'boolean', 'function' }, true }, + max_items = { provider.max_items, { 'number', 'function' }, true }, + min_keyword_length = { provider.min_keyword_length, { 'number', 'function' }, true }, + fallbacks = { provider.fallback_for, { 'table', 'function' }, true }, + score_offset = { provider.score_offset, { 'number', 'function' }, true }, + deduplicate = { provider.deduplicate, 'table', true }, + override = { provider.override, 'table', true }, + }) +end + return sources diff --git a/lua/blink/cmp/fuzzy/fuzzy.rs b/lua/blink/cmp/fuzzy/fuzzy.rs index 32b078e2..4d024935 100644 --- a/lua/blink/cmp/fuzzy/fuzzy.rs +++ b/lua/blink/cmp/fuzzy/fuzzy.rs @@ -5,7 +5,6 @@ use crate::lsp_item::LspItem; use mlua::prelude::*; use mlua::FromLua; use mlua::Lua; -use std::cmp::Reverse; use std::collections::HashSet; #[derive(Clone, Hash)] @@ -15,8 +14,6 @@ pub struct FuzzyOptions { use_proximity: bool, nearby_words: Option>, min_score: u16, - max_items: u32, - sorts: Vec, } impl FromLua for FuzzyOptions { @@ -27,8 +24,6 @@ impl FromLua for FuzzyOptions { let use_proximity: bool = tab.get("use_proximity").unwrap_or_default(); let nearby_words: Option> = tab.get("nearby_words").ok(); let min_score: u16 = tab.get("min_score").unwrap_or_default(); - let max_items: u32 = tab.get("max_items").unwrap_or_default(); - let sorts: Vec = tab.get("sorts").unwrap_or_default(); Ok(FuzzyOptions { use_typo_resistance, @@ -36,8 +31,6 @@ impl FromLua for FuzzyOptions { use_proximity, nearby_words, min_score, - max_items, - sorts, }) } else { Err(mlua::Error::FromLuaConversionError { @@ -51,10 +44,10 @@ impl FromLua for FuzzyOptions { pub fn fuzzy( needle: String, - haystack: Vec, + haystack: &Vec, frecency: &FrecencyTracker, opts: FuzzyOptions, -) -> Vec { +) -> (Vec, Vec) { let nearby_words: HashSet = HashSet::from_iter(opts.nearby_words.unwrap_or_default()); let haystack_labels = haystack .iter() @@ -110,42 +103,15 @@ pub fn fuzzy( .collect::>(); } - // Sort matches by sort criteria - for sort in opts.sorts.iter() { - match sort.as_str() { - "kind" => { - matches.sort_by_key(|mtch| haystack[mtch.index_in_haystack].kind); - } - "score" => { - matches.sort_by_cached_key(|mtch| Reverse(match_scores[mtch.index])); - } - "label" => { - matches.sort_by(|a, b| { - let label_a = haystack[a.index_in_haystack] - .sort_text - .as_ref() - .unwrap_or(&haystack[a.index_in_haystack].label); - let label_b = haystack[b.index_in_haystack] - .sort_text - .as_ref() - .unwrap_or(&haystack[b.index_in_haystack].label); - - // Put anything with an underscore at the end - match (label_a.starts_with('_'), label_b.starts_with('_')) { - (true, false) => std::cmp::Ordering::Greater, - (false, true) => std::cmp::Ordering::Less, - _ => label_a.cmp(label_b), - } - }); - } - _ => {} - } - } - - // Grab the top N matches and return the indices - matches - .iter() - .map(|mtch| mtch.index_in_haystack) - .take(opts.max_items as usize) - .collect::>() + // Return scores and indices + ( + matches + .iter() + .map(|mtch| match_scores[mtch.index] as i32) + .collect::>(), + matches + .iter() + .map(|mtch| mtch.index_in_haystack as u32) + .collect::>(), + ) } diff --git a/lua/blink/cmp/fuzzy/init.lua b/lua/blink/cmp/fuzzy/init.lua index 7f41833f..31fdfa40 100644 --- a/lua/blink/cmp/fuzzy/init.lua +++ b/lua/blink/cmp/fuzzy/init.lua @@ -2,6 +2,7 @@ local config = require('blink.cmp.config') local fuzzy = { rust = require('blink.cmp.fuzzy.rust'), + haystacks_by_provider_cache = {}, has_init_db = false, } @@ -24,13 +25,19 @@ function fuzzy.get_words(lines) return fuzzy.rust.get_words(lines) end function fuzzy.fuzzy_matched_indices(needle, haystack) return fuzzy.rust.fuzzy_matched_indices(needle, haystack) end ----@param needle string ----@param haystack blink.cmp.CompletionItem[]? ----@return blink.cmp.CompletionItem[] -function fuzzy.fuzzy(needle, haystack) +--- @param needle string +--- @param haystacks_by_provider table +--- @return blink.cmp.CompletionItem[] +function fuzzy.fuzzy(needle, haystacks_by_provider) fuzzy.init_db() - haystack = haystack or {} + for provider_id, haystack in pairs(haystacks_by_provider) do + -- set the provider items once since Lua <-> Rust takes the majority of the time + if fuzzy.haystacks_by_provider_cache[provider_id] ~= haystack then + fuzzy.haystacks_by_provider_cache[provider_id] = haystack + fuzzy.rust.set_provider_items(provider_id, haystack) + end + end -- get the nearby words local cursor_row = vim.api.nvim_win_get_cursor(0)[1] @@ -39,25 +46,29 @@ function fuzzy.fuzzy(needle, haystack) local nearby_text = table.concat(vim.api.nvim_buf_get_lines(0, start_row, end_row, false), '\n') local nearby_words = #nearby_text < 10000 and fuzzy.rust.get_words(nearby_text) or {} - -- perform fuzzy search - local matched_indices = fuzzy.rust.fuzzy(needle, haystack, { - -- each matching char is worth 4 points and it receives a bonus for capitalization, delimiter and prefix - -- so this should generally be good - -- TODO: make this configurable - min_score = config.fuzzy.use_typo_resistance and (6 * needle:len()) or 0, - max_items = config.completion.list.max_items, - use_typo_resistance = config.fuzzy.use_typo_resistance, - use_frecency = config.fuzzy.use_frecency, - use_proximity = config.fuzzy.use_proximity, - sorts = config.fuzzy.sorts, - nearby_words = nearby_words, - }) - local filtered_items = {} - for _, idx in ipairs(matched_indices) do - table.insert(filtered_items, haystack[idx + 1]) + for provider_id, haystack in pairs(haystacks_by_provider) do + -- perform fuzzy search + local scores, matched_indices = fuzzy.rust.fuzzy(needle, provider_id, { + -- each matching char is worth 4 points and it receives a bonus for capitalization, delimiter and prefix + -- so this should generally be good + -- TODO: make this configurable + min_score = config.fuzzy.use_typo_resistance and (6 * needle:len()) or 0, + use_typo_resistance = config.fuzzy.use_typo_resistance, + use_frecency = config.fuzzy.use_frecency, + use_proximity = config.fuzzy.use_proximity, + sorts = config.fuzzy.sorts, + nearby_words = nearby_words, + }) + + for idx, item_index in ipairs(matched_indices) do + local item = haystack[item_index + 1] + item.score = scores[idx] + table.insert(filtered_items, item) + end end - return filtered_items + + return require('blink.cmp.fuzzy.sort').sort(filtered_items) end --- Gets the text under the cursor to be used for fuzzy matching diff --git a/lua/blink/cmp/fuzzy/lib.rs b/lua/blink/cmp/fuzzy/lib.rs index 7f193210..353f235b 100644 --- a/lua/blink/cmp/fuzzy/lib.rs +++ b/lua/blink/cmp/fuzzy/lib.rs @@ -4,7 +4,7 @@ use crate::lsp_item::LspItem; use lazy_static::lazy_static; use mlua::prelude::*; use regex::Regex; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::sync::RwLock; mod frecency; @@ -14,6 +14,8 @@ mod lsp_item; lazy_static! { static ref REGEX: Regex = Regex::new(r"\p{L}[\p{L}0-9_\\-]{2,32}").unwrap(); static ref FRECENCY: RwLock> = RwLock::new(None); + static ref HAYSTACKS_BY_PROVIDER: RwLock>> = + RwLock::new(HashMap::new()); } pub fn init_db(_: &Lua, db_path: String) -> LuaResult { @@ -52,10 +54,21 @@ pub fn access(_: &Lua, item: LspItem) -> LuaResult { Ok(true) } +pub fn set_provider_items( + _: &Lua, + (provider_id, items): (String, Vec), +) -> LuaResult { + let mut items_by_provider = HAYSTACKS_BY_PROVIDER.write().map_err(|_| { + mlua::Error::RuntimeError("Failed to acquire lock for items by provider".to_string()) + })?; + items_by_provider.insert(provider_id, items); + Ok(true) +} + pub fn fuzzy( _lua: &Lua, - (needle, haystack, opts): (String, Vec, FuzzyOptions), -) -> LuaResult> { + (needle, provider_id, opts): (String, String, FuzzyOptions), +) -> LuaResult<(Vec, Vec)> { let mut frecency_handle = FRECENCY.write().map_err(|_| { mlua::Error::RuntimeError("Failed to acquire lock for frecency".to_string()) })?; @@ -63,10 +76,17 @@ pub fn fuzzy( mlua::Error::RuntimeError("Attempted to use frencecy before initialization".to_string()) })?; - Ok(fuzzy::fuzzy(needle, haystack, frecency, opts) - .into_iter() - .map(|i| i as u32) - .collect()) + let haystacks_by_provider = HAYSTACKS_BY_PROVIDER.read().map_err(|_| { + mlua::Error::RuntimeError("Failed to acquire lock for items by provider".to_string()) + })?; + let haystack = haystacks_by_provider.get(&provider_id).ok_or_else(|| { + mlua::Error::RuntimeError(format!( + "Attempted to fuzzy match for provider {} before setting the provider's items", + provider_id + )) + })?; + + Ok(fuzzy::fuzzy(needle, haystack, frecency, opts)) } pub fn fuzzy_matched_indices( @@ -93,6 +113,10 @@ pub fn get_words(_: &Lua, text: String) -> LuaResult> { #[mlua::lua_module(skip_memory_check)] fn blink_cmp_fuzzy(lua: &Lua) -> LuaResult { let exports = lua.create_table()?; + exports.set( + "set_provider_items", + lua.create_function(set_provider_items)?, + )?; exports.set("fuzzy", lua.create_function(fuzzy)?)?; exports.set( "fuzzy_matched_indices", diff --git a/lua/blink/cmp/fuzzy/sort.lua b/lua/blink/cmp/fuzzy/sort.lua new file mode 100644 index 00000000..ecfc8d41 --- /dev/null +++ b/lua/blink/cmp/fuzzy/sort.lua @@ -0,0 +1,44 @@ +local sort = {} + +--- @param list blink.cmp.CompletionItem[] +--- @return blink.cmp.CompletionItem[] +function sort.sort(list) + local config = require('blink.cmp.config').fuzzy.sorts + local sorting_funcs = vim.tbl_map( + function(name_or_func) return type(name_or_func) == 'string' and sort[name_or_func] or name_or_func end, + config + ) + table.sort(list, function(a, b) + for _, sorting_func in ipairs(sorting_funcs) do + local result = sorting_func(a, b) + if result ~= nil then return result end + end + end) + return list +end + +function sort.score(a, b) + if a.score == b.score then return end + return a.score > b.score +end + +function sort.kind(a, b) + if a.kind == b.kind then return end + return a.kind < b.kind +end + +function sort.label(a, b) + local label_a = a.sortText or a.label + local label_b = b.sortText or b.label + local _, entry1_under = label_a:find('^_+') + local _, entry2_under = label_b:find('^_+') + entry1_under = entry1_under or 0 + entry2_under = entry2_under or 0 + if entry1_under > entry2_under then + return false + elseif entry1_under < entry2_under then + return true + end +end + +return sort diff --git a/lua/blink/cmp/init.lua b/lua/blink/cmp/init.lua index 85278f1a..e4f2d198 100644 --- a/lua/blink/cmp/init.lua +++ b/lua/blink/cmp/init.lua @@ -19,12 +19,13 @@ end ------- Public API ------- -function cmp.show() - if require('blink.cmp.completion.windows.menu').win:is_open() then return end +--- @params opts? { providers?: string[] } +function cmp.show(opts) + if require('blink.cmp.completion.windows.menu').win:is_open() and not (opts and opts.providers) then return end vim.schedule(function() require('blink.cmp.completion.windows.menu').auto_show = true - require('blink.cmp.completion.trigger').show({ force = true }) + require('blink.cmp.completion.trigger').show({ force = true, providers = opts and opts.providers }) end) return true end @@ -146,4 +147,13 @@ function cmp.get_lsp_capabilities(override, include_nvim_defaults) return require('blink.cmp.sources.lib').get_lsp_capabilities(override, include_nvim_defaults) end +--- @param id string +--- @param provider_config blink.cmp.SourceProviderConfig +function cmp.add_provider(id, provider_config) + local config = require('blink.cmp.config') + assert(config.sources.providers[id] == nil, 'Provider with id ' .. id .. ' already exists') + require('blink.cmp.config.sources').validate_provider(id, provider_config) + config.sources.providers[id] = provider_config +end + return cmp diff --git a/lua/blink/cmp/lib/async.lua b/lua/blink/cmp/lib/async.lua index 45c2fb62..d4155e2e 100644 --- a/lua/blink/cmp/lib/async.lua +++ b/lua/blink/cmp/lib/async.lua @@ -151,33 +151,51 @@ end --- utils function task.await_all(tasks) - return task.new(function(resolve) + if #tasks == 0 then return task.empty() end + + local all_task + all_task = task.new(function(resolve, reject) local results = {} + local has_resolved = {} local function resolve_if_completed() -- we can't check #results directly because a table like -- { [2] = { ... } } has a length of 2 for i = 1, #tasks do - if results[i] == nil then return end + if has_resolved[i] == nil then return end end resolve(results) end for idx, task in ipairs(tasks) do task:on_completion(function(result) - results[idx] = { status = STATUS.COMPLETED, result = result } + results[idx] = result + has_resolved[idx] = true resolve_if_completed() end) task:on_failure(function(err) - results[idx] = { status = STATUS.FAILED, err = err } - resolve_if_completed() + for _, task in ipairs(tasks) do + task:cancel() + end + reject(err) end) task:on_cancel(function() - results[idx] = { status = STATUS.CANCELLED } - resolve_if_completed() + for _, sub_task in ipairs(tasks) do + sub_task:cancel() + end + if all_task == nil then + vim.schedule(function() all_task:cancel() end) + else + all_task:cancel() + end end) end end) + return all_task +end + +function task.empty() + return task.new(function(resolve) resolve() end) end return { task = task, STATUS = STATUS } diff --git a/lua/blink/cmp/lib/event_emitter.lua b/lua/blink/cmp/lib/event_emitter.lua index d3939cb9..0387a8f7 100644 --- a/lua/blink/cmp/lib/event_emitter.lua +++ b/lua/blink/cmp/lib/event_emitter.lua @@ -27,11 +27,11 @@ function event_emitter:emit(data) for _, callback in ipairs(self.listeners) do callback(data) end - if self.autocmd then - require('blink.cmp.lib.utils').schedule_if_needed( - function() vim.api.nvim_exec_autocmds('User', { pattern = self.autocmd, modeline = false, data = data }) end - ) - end + -- if self.autocmd then + -- require('blink.cmp.lib.utils').schedule_if_needed( + -- function() vim.api.nvim_exec_autocmds('User', { pattern = self.autocmd, modeline = false, data = data }) end + -- ) + -- end end return event_emitter diff --git a/lua/blink/cmp/lib/utils.lua b/lua/blink/cmp/lib/utils.lua index ad7a206c..8f3233c2 100644 --- a/lua/blink/cmp/lib/utils.lua +++ b/lua/blink/cmp/lib/utils.lua @@ -93,4 +93,70 @@ function utils.schedule_if_needed(fn) end end +--- Flattens an arbitrarily deep table into a single level table +--- @param t table +--- @return table +function utils.flatten(t) + if t[1] == nil then return t end + + local flattened = {} + for _, v in ipairs(t) do + if v[1] == nil then + table.insert(flattened, v) + else + vim.list_extend(flattened, utils.flatten(v)) + end + end + return flattened +end + +--- Returns the index of the first occurrence of the value in the array +--- @generic T +--- @param arr T[] +--- @param val T +--- @return number | nil +function utils.index_of(arr, val) + for idx, v in ipairs(arr) do + if v == val then return idx end + end + return nil +end + +--- Finds an item in an array using a predicate function +--- @generic T +--- @param arr T[] +--- @param predicate fun(item: T): boolean +--- @return T | nil +function utils.find_idx(arr, predicate) + for idx, v in ipairs(arr) do + if predicate(v) then return idx end + end + return nil +end + +--- Slices an array +--- @generic T +--- @param arr T[] +--- @param start number +--- @param finish number +--- @return T[] +function utils.slice(arr, start, finish) + start = start or 1 + finish = finish or #arr + local sliced = {} + for i = start, finish do + sliced[#sliced + 1] = arr[i] + end + return sliced +end + +function utils.fast_gsub(str, old_char, new_char) + local result = '' + for i = 1, #str do + local c = str:sub(i, i) + result = result .. (c == old_char and new_char or c) + end + return result +end + return utils diff --git a/lua/blink/cmp/sources/lib/context.lua b/lua/blink/cmp/sources/lib/context.lua deleted file mode 100644 index f196fcc8..00000000 --- a/lua/blink/cmp/sources/lib/context.lua +++ /dev/null @@ -1,155 +0,0 @@ -local utils = require('blink.cmp.sources.lib.utils') -local async = require('blink.cmp.lib.async') - ---- @class blink.cmp.SourcesContext ---- @field id number ---- @field sources table ---- @field active_request blink.cmp.Task | nil ---- @field queued_request_context blink.cmp.Context | nil ---- @field cached_responses table | nil ---- @field on_completions_callback fun(context: blink.cmp.Context, enabled_sources: string[], responses: table) ---- ---- @field new fun(context: blink.cmp.Context, sources: table, on_completions_callback: fun(context: blink.cmp.Context, items: table)): blink.cmp.SourcesContext ---- @field get_sources fun(self: blink.cmp.SourcesContext): string[] ---- @field get_cached_completions fun(self: blink.cmp.SourcesContext): table | nil ---- @field get_completions fun(self: blink.cmp.SourcesContext, context: blink.cmp.Context) ---- @field get_completions_for_sources fun(self: blink.cmp.SourcesContext, sources: table, context: blink.cmp.Context): blink.cmp.Task ---- @field get_completions_with_fallbacks fun(self: blink.cmp.SourcesContext, context: blink.cmp.Context, source: blink.cmp.SourceProvider, sources: table): blink.cmp.Task ---- @field destroy fun(self: blink.cmp.SourcesContext) - ---- @type blink.cmp.SourcesContext ---- @diagnostic disable-next-line: missing-fields -local sources_context = {} - -function sources_context.new(context, sources, on_completions_callback) - local self = setmetatable({}, { __index = sources_context }) - self.id = context.id - self.sources = sources - - self.active_request = nil - self.queued_request_context = nil - self.on_completions_callback = on_completions_callback - - return self -end - -function sources_context:get_sources() return vim.tbl_keys(self.sources) end - -function sources_context:get_cached_completions() return self.cached_responses end - -function sources_context:get_completions(context) - assert(context.id == self.id, 'Requested completions on a sources context with a different context ID') - - if self.active_request ~= nil and self.active_request.status == async.STATUS.RUNNING then - self.queued_request_context = context - return - end - - -- Create a task to get the completions, send responses upstream - -- and run the queued request, if it exists - self.active_request = self:get_completions_for_sources(self.sources, context):map(function(responses) - self.cached_responses = responses - --- @cast responses table - self.active_request = nil - - -- only send upstream if the responses contain something new - local is_cached = true - for _, response in pairs(responses) do - is_cached = is_cached and (response.is_cached or false) - end - if not is_cached then self.on_completions_callback(context, self:get_sources(), responses) end - - -- run the queued request, if it exists - if self.queued_request_context ~= nil then - local queued_context = self.queued_request_context - self.queued_request_context = nil - self:get_completions(queued_context) - end - end) -end - -function sources_context:get_completions_for_sources(sources, context) - local enabled_sources = vim.tbl_keys(sources) - --- @type blink.cmp.SourceProvider[] - local non_fallback_sources = vim.tbl_filter(function(source) - local fallbacks = source.config.fallback_for and source.config.fallback_for(context, enabled_sources) or {} - fallbacks = vim.tbl_filter(function(fallback) return sources[fallback] end, fallbacks) - return #fallbacks == 0 - end, vim.tbl_values(sources)) - - -- get completions for each non-fallback source - local tasks = vim.tbl_map(function(source) - -- the source indicates we should refetch when this character is typed - local trigger_character = context.trigger.character - and vim.tbl_contains(source:get_trigger_characters(), context.trigger.character) - - -- The TriggerForIncompleteCompletions kind is handled by the source provider itself - local source_context = require('blink.cmp.lib.utils').shallow_copy(context) - source_context.trigger = trigger_character - and { kind = vim.lsp.protocol.CompletionTriggerKind.TriggerCharacter, character = context.trigger.character } - or { kind = vim.lsp.protocol.CompletionTriggerKind.Invoked } - - return self:get_completions_with_fallbacks(source_context, source, sources) - end, non_fallback_sources) - - -- wait for all the tasks to complete - return async.task - .await_all(tasks) - :map(function(tasks_results) - local responses = {} - for idx, task_result in ipairs(tasks_results) do - if task_result.status == async.STATUS.COMPLETED then - --- @type blink.cmp.SourceProvider - local source = vim.tbl_values(non_fallback_sources)[idx] - responses[source.id] = task_result.result - end - end - return responses - end) - :catch(function(err) - vim.print('failed to get completions for sources with error: ' .. err) - return {} - end) -end - ---- Runs the source's get_completions function, falling back to other sources ---- with fallback_for = { source.name } if the source returns no completion items ---- TODO: When a source has multiple fallbacks, we may end up with duplicate completion items -function sources_context:get_completions_with_fallbacks(context, source, sources) - local enabled_sources = vim.tbl_keys(sources) - local fallback_sources = vim.tbl_filter( - function(fallback_source) - return fallback_source.id ~= source.id - and fallback_source.config.fallback_for ~= nil - and vim.tbl_contains(fallback_source.config.fallback_for(context), source.id) - end, - vim.tbl_values(sources) - ) - - return source:get_completions(context, enabled_sources):map(function(response) - -- source returned completions, no need to fallback - if #response.items > 0 or #fallback_sources == 0 then return response end - - -- run fallbacks - return async.task - .await_all(vim.tbl_map(function(fallback) return fallback:get_completions(context) end, fallback_sources)) - :map(function(task_results) - local successful_task_results = vim.tbl_filter( - function(task_result) return task_result.status == async.STATUS.COMPLETED end, - task_results - ) - local fallback_responses = vim.tbl_map( - function(task_result) return task_result.result end, - successful_task_results - ) - return utils.concat_responses(fallback_responses) - end) - end) -end - -function sources_context:destroy() - self.on_completions_callback = function() end - if self.active_request ~= nil then self.active_request:cancel() end -end - -return sources_context diff --git a/lua/blink/cmp/sources/lib/init.lua b/lua/blink/cmp/sources/lib/init.lua index 0a72277f..f6c111ca 100644 --- a/lua/blink/cmp/sources/lib/init.lua +++ b/lua/blink/cmp/sources/lib/init.lua @@ -2,20 +2,22 @@ local async = require('blink.cmp.lib.async') local config = require('blink.cmp.config') --- @class blink.cmp.Sources ---- @field current_context blink.cmp.SourcesContext | nil +--- @field completions_queue blink.cmp.SourcesContext | nil --- @field current_signature_help blink.cmp.Task | nil --- @field sources_registered boolean --- @field providers table --- @field completions_emitter blink.cmp.EventEmitter --- ---- @field get_enabled_providers fun(context?: blink.cmp.Context): table +--- @field get_all_providers fun(): blink.cmp.SourceProvider[] +--- @field get_enabled_provider_ids fun(): string[] +--- @field get_enabled_providers fun(): table --- @field get_trigger_characters fun(): string[] --- ---- @field emit_completions fun(context: blink.cmp.Context, enabled_sources: table, responses: table) +--- @field emit_completions fun(context: blink.cmp.Context, responses: table) --- @field request_completions fun(context: blink.cmp.Context) --- @field cancel_completions fun() ---- @field listen_on_completions fun(callback: fun(context: blink.cmp.Context, items: blink.cmp.CompletionItem[])) --- @field apply_max_items_for_completions fun(context: blink.cmp.Context, items: blink.cmp.CompletionItem[]): blink.cmp.CompletionItem[] +--- @field listen_on_completions fun(callback: fun(context: blink.cmp.Context, items: blink.cmp.CompletionItem[])) --- @field resolve fun(item: blink.cmp.CompletionItem): blink.cmp.Task --- @field execute fun(context: blink.cmp.Context, item: blink.cmp.CompletionItem): blink.cmp.Task --- @@ -28,105 +30,112 @@ local config = require('blink.cmp.config') --- @class blink.cmp.SourceCompletionsEvent --- @field context blink.cmp.Context ---- @field items blink.cmp.CompletionItem[] +--- @field items table --- @type blink.cmp.Sources --- @diagnostic disable-next-line: missing-fields local sources = { - current_context = nil, + completions_queue = nil, providers = {}, completions_emitter = require('blink.cmp.lib.event_emitter').new('source_completions', 'BlinkCmpSourceCompletions'), } -function sources.get_enabled_providers(context) - local mode_providers = type(config.sources.completion.enabled_providers) == 'function' - and config.sources.completion.enabled_providers(context) - or config.sources.completion.enabled_providers - --- @cast mode_providers string[] - - for _, provider in ipairs(mode_providers) do - assert( - sources.providers[provider] ~= nil or config.sources.providers[provider] ~= nil, - 'Requested provider "' - .. provider - .. '" has not been configured. Available providers: ' - .. vim.fn.join(vim.tbl_keys(sources.providers), ', ') - ) - -- initialize the provider if it hasn't been initialized yet - if not sources.providers[provider] then - sources.providers[provider] = - require('blink.cmp.sources.lib.provider').new(provider, config.sources.providers[provider]) - end +function sources.get_all_providers() + local providers = {} + for provider_id, _ in pairs(config.sources.providers) do + providers[provider_id] = sources.get_provider_by_id(provider_id) end + return providers +end + +function sources.get_enabled_provider_ids() + local enabled_providers = config.sources.per_filetype[vim.bo.filetype] or config.sources.default + if type(enabled_providers) == 'function' then return enabled_providers() end + --- @cast enabled_providers string[] + return enabled_providers +end + +function sources.get_enabled_providers() + local mode_providers = sources.get_enabled_provider_ids() --- @type table local providers = {} - for key, provider in pairs(sources.providers) do - if vim.tbl_contains(mode_providers, key) and provider:enabled(context) then providers[key] = provider end + for _, provider_id in ipairs(mode_providers) do + local provider = sources.get_provider_by_id(provider_id) + if provider:enabled() then providers[provider_id] = sources.get_provider_by_id(provider_id) end end return providers end +function sources.get_provider_by_id(provider_id) + assert( + sources.providers[provider_id] ~= nil or config.sources.providers[provider_id] ~= nil, + 'Requested provider "' + .. provider_id + .. '" has not been configured. Available providers: ' + .. vim.fn.join(vim.tbl_keys(sources.providers), ', ') + ) + + -- initialize the provider if it hasn't been initialized yet + if not sources.providers[provider_id] then + local provider_config = config.sources.providers[provider_id] + sources.providers[provider_id] = require('blink.cmp.sources.lib.provider').new(provider_id, provider_config) + end + + return sources.providers[provider_id] +end + --- Completion --- function sources.get_trigger_characters() local providers = sources.get_enabled_providers() local trigger_characters = {} - for _, source in pairs(providers) do - vim.list_extend(trigger_characters, source:get_trigger_characters()) + for _, provider in pairs(providers) do + vim.list_extend(trigger_characters, provider:get_trigger_characters()) end return trigger_characters end -function sources.emit_completions(context, enabled_sources, responses) - local items = {} - for id, response in pairs(responses) do - if sources.providers[id]:should_show_items(context, enabled_sources, response.items) then - vim.list_extend(items, response.items) - end +function sources.emit_completions(context, _items_by_provider) + local items_by_provider = {} + for id, items in pairs(_items_by_provider) do + if sources.providers[id]:should_show_items(context, items) then items_by_provider[id] = items end end - sources.completions_emitter:emit({ context = context, items = items }) + sources.completions_emitter:emit({ context = context, items = items_by_provider }) end function sources.request_completions(context) -- create a new context if the id changed or if we haven't created one yet - local is_new_context = sources.current_context == nil or context.id ~= sources.current_context.id - if is_new_context then - if sources.current_context ~= nil then sources.current_context:destroy() end - sources.current_context = require('blink.cmp.sources.lib.context').new( - context, - sources.get_enabled_providers(context), - sources.emit_completions - ) + if sources.completions_queue == nil or context.id ~= sources.completions_queue.id then + if sources.completions_queue ~= nil then sources.completions_queue:destroy() end + sources.completions_queue = + require('blink.cmp.sources.lib.queue').new(context, sources.get_all_providers(), sources.emit_completions) -- send cached completions if they exist to immediately trigger updates - elseif sources.current_context:get_cached_completions() ~= nil then + elseif sources.completions_queue:get_cached_completions() ~= nil then sources.emit_completions( context, - sources.current_context:get_sources(), --- @diagnostic disable-next-line: param-type-mismatch - sources.current_context:get_cached_completions() + sources.completions_queue:get_cached_completions() ) end - sources.current_context:get_completions(context) + sources.completions_queue:get_completions(context) end function sources.cancel_completions() - if sources.current_context ~= nil then - sources.current_context:destroy() - sources.current_context = nil + if sources.completions_queue ~= nil then + sources.completions_queue:destroy() + sources.completions_queue = nil end end --- Limits the number of items per source as configured function sources.apply_max_items_for_completions(context, items) - local enabled_sources = sources.get_enabled_providers(context) - -- get the configured max items for each source local total_items_for_sources = {} local max_items_for_sources = {} for id, source in pairs(sources.providers) do - max_items_for_sources[id] = source.config.max_items(context, enabled_sources, items) + max_items_for_sources[id] = source.config.max_items(context, items) total_items_for_sources[id] = 0 end @@ -199,13 +208,8 @@ function sources.get_signature_help(context, callback) table.insert(tasks, source:get_signature_help(context)) end - sources.current_signature_help = async.task.await_all(tasks):map(function(tasks_results) - local signature_helps = {} - for _, task_result in ipairs(tasks_results) do - if task_result.status == async.STATUS.COMPLETED and task_result.result ~= nil then - table.insert(signature_helps, task_result.result) - end - end + sources.current_signature_help = async.task.await_all(tasks):map(function(signature_helps) + signature_helps = vim.tbl_filter(function(signature_help) return signature_help ~= nil end, signature_helps) callback(signature_helps[1]) end) end diff --git a/lua/blink/cmp/sources/lib/provider/config.lua b/lua/blink/cmp/sources/lib/provider/config.lua index a3e4c100..6eeea675 100644 --- a/lua/blink/cmp/sources/lib/provider/config.lua +++ b/lua/blink/cmp/sources/lib/provider/config.lua @@ -2,13 +2,15 @@ --- @field new fun(config: blink.cmp.SourceProviderConfig): blink.cmp.SourceProviderConfigWrapper --- @field name string --- @field module string ---- @field enabled fun(ctx: blink.cmp.Context): boolean +--- @field enabled fun(): boolean +--- @field async fun(ctx: blink.cmp.Context): boolean +--- @field timeout_ms fun(ctx: blink.cmp.Context): number --- @field transform_items fun(ctx: blink.cmp.Context, items: blink.cmp.CompletionItem[]): blink.cmp.CompletionItem[] --- @field should_show_items fun(ctx: blink.cmp.Context, items: blink.cmp.CompletionItem[]): boolean ---- @field max_items? fun(ctx: blink.cmp.Context, enabled_sources: string[], items: blink.cmp.CompletionItem[]): number ---- @field min_keyword_length fun(ctx: blink.cmp.Context, enabled_sources: string[]): number ---- @field fallback_for fun(ctx: blink.cmp.Context, enabled_sources: string[]): string[] ---- @field score_offset fun(ctx: blink.cmp.Context, enabled_sources: string[]): number +--- @field max_items? fun(ctx: blink.cmp.Context, items: blink.cmp.CompletionItem[]): number +--- @field min_keyword_length fun(ctx: blink.cmp.Context): number +--- @field fallbacks fun(ctx: blink.cmp.Context): string[] +--- @field score_offset fun(ctx: blink.cmp.Context): number --- @class blink.cmp.SourceProviderConfig --- @diagnostic disable-next-line: missing-fields @@ -27,11 +29,13 @@ function wrapper.new(config) self.name = config.name self.module = config.module self.enabled = call_or_get(config.enabled, true) + self.async = call_or_get(config.async, false) + self.timeout_ms = call_or_get(config.timeout, 2000) self.transform_items = config.transform_items or function(_, items) return items end self.should_show_items = call_or_get(config.should_show_items, true) self.max_items = call_or_get(config.max_items, nil) self.min_keyword_length = call_or_get(config.min_keyword_length, 0) - self.fallback_for = call_or_get(config.fallback_for, {}) + self.fallbacks = call_or_get(config.fallbacks, {}) self.score_offset = call_or_get(config.score_offset, 0) return self end diff --git a/lua/blink/cmp/sources/lib/provider/init.lua b/lua/blink/cmp/sources/lib/provider/init.lua index ca431ef3..f378dde6 100644 --- a/lua/blink/cmp/sources/lib/provider/init.lua +++ b/lua/blink/cmp/sources/lib/provider/init.lua @@ -4,14 +4,14 @@ --- @field name string --- @field config blink.cmp.SourceProviderConfigWrapper --- @field module blink.cmp.Source ---- @field last_response blink.cmp.CompletionResponse | nil +--- @field list blink.cmp.SourceProviderList | nil --- @field resolve_tasks table --- --- @field new fun(id: string, config: blink.cmp.SourceProviderConfig): blink.cmp.SourceProvider ---- @field enabled fun(self: blink.cmp.SourceProvider, context: blink.cmp.Context): boolean +--- @field enabled fun(self: blink.cmp.SourceProvider): boolean --- @field get_trigger_characters fun(self: blink.cmp.SourceProvider): string[] ---- @field get_completions fun(self: blink.cmp.SourceProvider, context: blink.cmp.Context, enabled_sources: string[]): blink.cmp.Task ---- @field should_show_items fun(self: blink.cmp.SourceProvider, context: blink.cmp.Context, enabled_sources: string[], response: blink.cmp.CompletionResponse): boolean +--- @field get_completions fun(self: blink.cmp.SourceProvider, context: blink.cmp.Context, on_items: fun(items: blink.cmp.CompletionItem[], is_cached: boolean)) +--- @field should_show_items fun(self: blink.cmp.SourceProvider, context: blink.cmp.Context, items: blink.cmp.CompletionItem[]): boolean --- @field resolve fun(self: blink.cmp.SourceProvider, item: blink.cmp.CompletionItem): blink.cmp.Task --- @field execute fun(self: blink.cmp.SourceProvider, context: blink.cmp.Context, item: blink.cmp.CompletionItem, callback: fun()): blink.cmp.Task --- @field get_signature_help_trigger_characters fun(self: blink.cmp.SourceProvider): { trigger_characters: string[], retrigger_characters: string[] } @@ -22,7 +22,6 @@ --- @diagnostic disable-next-line: missing-fields local source = {} -local utils = require('blink.cmp.sources.lib.utils') local async = require('blink.cmp.lib.async') function source.new(id, config) @@ -37,19 +36,19 @@ function source.new(id, config) config.override ) self.config = require('blink.cmp.sources.lib.provider.config').new(config) - self.last_response = nil + self.list = nil self.resolve_tasks = {} return self end -function source:enabled(context) +function source:enabled() -- user defined - if not self.config.enabled(context) then return false end + if not self.config.enabled() then return false end -- source defined if self.module.enabled == nil then return true end - return self.module:enabled(context) + return self.module:enabled() end --- Completion --- @@ -59,58 +58,48 @@ function source:get_trigger_characters() return self.module:get_trigger_characters() end -function source:get_completions(context, enabled_sources) - -- Return the previous successful completions if the context is the same +function source:get_completions(context, on_items) + -- return the previous successful completions if the context is the same -- and the data doesn't need to be updated - if self.last_response ~= nil and self.last_response.context.id == context.id then - if utils.should_run_request(context, self.last_response) == false then - return async.task.new( - function(resolve) resolve(require('blink.cmp.lib.utils').shallow_copy(self.last_response)) end - ) - end + -- or if the list is async, since we don't want to cause a flash of no items + if self.list ~= nil and self.list:is_valid_for_context(context) then + self.list:set_on_items(on_items) + self.list:emit(true) + return end - return async.task - .new(function(resolve) - if self.module.get_completions == nil then return resolve() end - return self.module:get_completions(context, resolve) - end) - :map(function(response) - if response == nil then response = { is_incomplete_forward = true, is_incomplete_backward = true, items = {} } end - response.context = context - - -- add non-lsp metadata - local source_score_offset = self.config.score_offset(context, enabled_sources) or 0 - for _, item in ipairs(response.items) do - item.score_offset = (item.score_offset or 0) + source_score_offset - item.cursor_column = context.cursor[2] - item.source_id = self.id - item.source_name = self.name - end - - -- if the user provided a transform_items function, run it - if self.config.transform_items ~= nil then - response.items = self.config.transform_items(context, response.items) - end - - self.last_response = require('blink.cmp.lib.utils').shallow_copy(response) - self.last_response.is_cached = true - return response - end) - :catch(function(err) - vim.print('failed to get completions with error: ' .. err) - return { is_incomplete_forward = false, is_incomplete_backward = false, items = {} } - end) + -- the source indicates we should refetch when this character is typed + local trigger_character = context.trigger.character + and vim.tbl_contains(self:get_trigger_characters(), context.trigger.character) + + -- The TriggerForIncompleteCompletions kind is handled by the source provider itself + local source_context = require('blink.cmp.lib.utils').shallow_copy(context) + source_context.trigger = trigger_character + and { kind = vim.lsp.protocol.CompletionTriggerKind.TriggerCharacter, character = context.trigger.character } + or { kind = vim.lsp.protocol.CompletionTriggerKind.Invoked } + + local async_initial_items = self.list ~= nil and self.list.context.id == context.id and self.list.items or {} + if self.list ~= nil then self.list:destroy() end + + self.list = require('blink.cmp.sources.lib.provider.list').new( + self, + context, + on_items, + -- HACK: if the source is async, we're not reusing the previous list and the response was marked as incomplete, + -- the user will see a flash of no items from the provider, since the list emits immediately. So we hack around + -- this for now + { async_initial_items = async_initial_items } + ) end -function source:should_show_items(context, enabled_sources, response) +function source:should_show_items(context, items) -- if keyword length is configured, check if the context is long enough - local min_keyword_length = self.config.min_keyword_length(context, enabled_sources) + local min_keyword_length = self.config.min_keyword_length(context) local current_keyword_length = context.bounds.length if current_keyword_length < min_keyword_length then return false end if self.config.should_show_items == nil then return true end - return self.config.should_show_items(context, response.items) + return self.config.should_show_items(context, items) end --- Resolve --- diff --git a/lua/blink/cmp/sources/lib/provider/list.lua b/lua/blink/cmp/sources/lib/provider/list.lua new file mode 100644 index 00000000..ebce2b90 --- /dev/null +++ b/lua/blink/cmp/sources/lib/provider/list.lua @@ -0,0 +1,127 @@ +--- @class blink.cmp.SourceProviderList +--- @field provider blink.cmp.SourceProvider +--- @field context blink.cmp.Context +--- @field items blink.cmp.CompletionItem[] +--- @field on_items fun(items: blink.cmp.CompletionItem[], is_cached: boolean) +--- @field has_completed boolean +--- @field is_incomplete_backward boolean +--- @field is_incomplete_forward boolean +--- @field cancel_completions? fun(): nil +--- +--- @field new fun(provider: blink.cmp.SourceProvider,context: blink.cmp.Context, on_items: fun(items: blink.cmp.CompletionItem[], is_cached: boolean), opts: blink.cmp.SourceProviderListOpts): blink.cmp.SourceProviderList +--- @field append fun(self: blink.cmp.SourceProviderList, response: blink.cmp.CompletionResponse) +--- @field emit fun(self: blink.cmp.SourceProviderList, is_cached?: boolean) +--- @field destroy fun(self: blink.cmp.SourceProviderList): nil +--- @field set_on_items fun(self: blink.cmp.SourceProviderList, on_items: fun(items: blink.cmp.CompletionItem[], is_cached: boolean)) +--- @field is_valid_for_context fun(self: blink.cmp.SourceProviderList, context: blink.cmp.Context): boolean +--- +--- @class blink.cmp.SourceProviderListOpts +--- @field async_initial_items blink.cmp.CompletionItem[] + +--- @type blink.cmp.SourceProviderList +--- @diagnostic disable-next-line: missing-fields +local list = {} + +function list.new(provider, context, on_items, opts) + --- @type blink.cmp.SourceProviderList + local self = setmetatable({ + provider = provider, + context = context, + items = opts.async_initial_items, + on_items = on_items, + + has_completed = false, + is_incomplete_backward = true, + is_incomplete_forward = true, + }, { __index = list }) + + -- Immediately fetch completions + local default_response = { + is_incomplete_forward = true, + is_incomplete_backward = true, + items = {}, + } + if self.provider.module.get_completions == nil then + self:append(default_response) + else + self.cancel_completions = self.provider.module:get_completions( + self.context, + function(response) self:append(response or default_response) end + ) + end + + -- if async, immediately send the default response/initial items + local is_async = self.provider.config.async(self.context) + if self.provider.config.async(self.context) and not self.has_completed then self:emit() end + + -- if not async and timeout is set, send the default response after the timeout + local timeout_ms = self.provider.config.timeout_ms(self.context) + if not is_async and timeout_ms > 0 then + vim.defer_fn(function() + if not self.has_completed then self:append(default_response) end + end, timeout_ms) + end + + return self +end + +function list:append(response) + if not self.has_completed then + self.has_completed = true + self.is_incomplete_backward = response.is_incomplete_backward + self.is_incomplete_forward = response.is_incomplete_forward + self.items = {} + end + + -- add non-lsp metadata + local source_score_offset = self.provider.config.score_offset(self.context) or 0 + for _, item in ipairs(response.items) do + item.score_offset = (item.score_offset or 0) + source_score_offset + item.cursor_column = self.context.cursor[2] + item.source_id = self.provider.id + item.source_name = self.provider.name + end + + -- combine with existing items + local new_items = {} + vim.list_extend(new_items, self.items) + vim.list_extend(new_items, response.items) + self.items = new_items + + -- if the user provided a transform_items function, run it + if self.provider.config.transform_items ~= nil then + self.items = self.provider.config.transform_items(self.context, self.items) + end + + self:emit() +end + +function list:emit(is_cached) + if is_cached == nil then is_cached = false end + self.on_items(self.items, is_cached) +end + +function list:destroy() + if self.cancel_completions ~= nil then self.cancel_completions() end + self.on_items = function() end +end + +function list:set_on_items(on_items) self.on_items = on_items end + +function list:is_valid_for_context(new_context) + if self.context.id ~= new_context.id then return false end + + -- get the text for the current and queued context + local old_context_query = self.context.line:sub(self.context.bounds.start_col, self.context.cursor[2]) + local new_context_query = new_context.line:sub(new_context.bounds.start_col, new_context.cursor[2]) + + -- check if the texts are overlapping + local is_before = vim.startswith(old_context_query, new_context_query) + local is_after = vim.startswith(new_context_query, old_context_query) + + return (is_before and not self.is_incomplete_backward) + or (is_after and not self.is_incomplete_forward) + or (is_after == is_before and not (self.is_incomplete_backward or self.is_incomplete_forward)) +end + +return list diff --git a/lua/blink/cmp/sources/lib/queue.lua b/lua/blink/cmp/sources/lib/queue.lua new file mode 100644 index 00000000..5f2b31e5 --- /dev/null +++ b/lua/blink/cmp/sources/lib/queue.lua @@ -0,0 +1,68 @@ +local async = require('blink.cmp.lib.async') + +--- @class blink.cmp.SourcesQueue +--- @field id number +--- @field providers table +--- @field request blink.cmp.Task | nil +--- @field queued_request_context blink.cmp.Context | nil +--- @field cached_items_by_provider table | nil +--- @field on_completions_callback fun(context: blink.cmp.Context, responses: table) +--- +--- @field new fun(context: blink.cmp.Context, providers: table, on_completions_callback: fun(context: blink.cmp.Context, responses: table)): blink.cmp.SourcesContext +--- @field get_cached_completions fun(self: blink.cmp.SourcesQueue): table | nil +--- @field get_completions fun(self: blink.cmp.SourcesQueue, context: blink.cmp.Context) +--- @field destroy fun(self: blink.cmp.SourcesQueue) + +--- @type blink.cmp.SourcesQueue +--- @diagnostic disable-next-line: missing-fields +local queue = {} + +function queue.new(context, providers, on_completions_callback) + local self = setmetatable({}, { __index = queue }) + self.id = context.id + self.providers = providers + + self.request = nil + self.queued_request_context = nil + self.on_completions_callback = on_completions_callback + + return self +end + +function queue:get_cached_completions() return self.cached_items_by_provider end + +function queue:get_completions(context) + assert(context.id == self.id, 'Requested completions on a sources context with a different context ID') + + if self.request ~= nil then + if self.request.status == async.STATUS.RUNNING then + self.queued_request_context = context + return + else + self.request:cancel() + end + end + + -- Create a task to get the completions, send responses upstream + -- and run the queued request, if it exists + local tree = require('blink.cmp.sources.lib.tree').new(context, vim.tbl_values(self.providers)) + self.request = tree:get_completions(context, function(items_by_provider) + self.cached_items_by_provider = items_by_provider + self.on_completions_callback(context, items_by_provider) + + -- run the queued request, if it exists + local queued_context = self.queued_request_context + if queued_context ~= nil then + self.queued_request_context = nil + self.request:cancel() + self:get_completions(queued_context) + end + end) +end + +function queue:destroy() + self.on_completions_callback = function() end + if self.request ~= nil then self.request:cancel() end +end + +return queue diff --git a/lua/blink/cmp/sources/lib/tree.lua b/lua/blink/cmp/sources/lib/tree.lua new file mode 100644 index 00000000..c89d04d0 --- /dev/null +++ b/lua/blink/cmp/sources/lib/tree.lua @@ -0,0 +1,168 @@ +--- @class blink.cmp.SourceTreeNode +--- @field id string +--- @field source blink.cmp.SourceProvider +--- @field dependencies blink.cmp.SourceTreeNode[] +--- @field dependents blink.cmp.SourceTreeNode[] + +--- @class blink.cmp.SourceTree +--- @field nodes blink.cmp.SourceTreeNode[] +--- @field new fun(context: blink.cmp.Context, all_sources: blink.cmp.SourceProvider[]): blink.cmp.SourceTree +--- @field get_completions fun(self: blink.cmp.SourceTree, context: blink.cmp.Context, on_items_by_provider: fun(items_by_provider: table)): blink.cmp.Task +--- @field emit_completions fun(self: blink.cmp.SourceTree, items_by_provider: table, on_items_by_provider: fun(items_by_provider: table)): nil +--- @field get_top_level_nodes fun(self: blink.cmp.SourceTree): blink.cmp.SourceTreeNode[] +--- @field detect_cycle fun(node: blink.cmp.SourceTreeNode, visited?: table, path?: table): boolean + +local utils = require('blink.cmp.lib.utils') +local async = require('blink.cmp.lib.async') + +--- @type blink.cmp.SourceTree +--- @diagnostic disable-next-line: missing-fields +local tree = {} + +--- @param context blink.cmp.Context +--- @param all_sources blink.cmp.SourceProvider[] +function tree.new(context, all_sources) + -- only include enabled sources for the given context + local sources = vim.tbl_filter( + function(source) return vim.tbl_contains(context.providers, source.id) and source:enabled(context) end, + all_sources + ) + local source_ids = vim.tbl_map(function(source) return source.id end, sources) + + -- create a node for each source + local nodes = vim.tbl_map( + function(source) return { id = source.id, source = source, dependencies = {}, dependents = {} } end, + sources + ) + + -- build the tree + for idx, source in ipairs(sources) do + local node = nodes[idx] + for _, fallback_source_id in ipairs(source.config.fallbacks(context, source_ids)) do + local fallback_node = nodes[utils.index_of(source_ids, fallback_source_id)] + if fallback_node ~= nil then + table.insert(node.dependents, fallback_node) + table.insert(fallback_node.dependencies, node) + end + end + end + + -- circular dependency check + for _, node in ipairs(nodes) do + tree.detect_cycle(node) + end + + return setmetatable({ nodes = nodes }, { __index = tree }) +end + +function tree:get_completions(context, on_items_by_provider) + local should_push_upstream = false + local items_by_provider = {} + local is_all_cached = true + local nodes_falling_back = {} + + --- @param node blink.cmp.SourceTreeNode + local function get_completions_for_node(node) + -- check that all the dependencies have been triggered, and are falling back + for _, dependency in ipairs(node.dependencies) do + if not nodes_falling_back[dependency.id] then return async.task.empty() end + end + + return async.task.new(function(resolve, reject) + return node.source:get_completions(context, function(items, is_cached) + items_by_provider[node.id] = items + is_all_cached = is_all_cached and is_cached + + if should_push_upstream then self:emit_completions(items_by_provider, on_items_by_provider) end + if #items ~= 0 then return resolve() end + + -- run dependents if the source returned 0 items + nodes_falling_back[node.id] = true + local tasks = vim.tbl_map(function(dependent) return get_completions_for_node(dependent) end, node.dependents) + async.task.await_all(tasks):map(resolve):catch(reject) + end) + end) + end + + -- run the top level nodes and let them fall back to their dependents if needed + local tasks = vim.tbl_map(function(node) return get_completions_for_node(node) end, self:get_top_level_nodes()) + return async.task + .await_all(tasks) + :map(function() + should_push_upstream = true + + -- if atleast one of the results wasn't cached, emit the results + if not is_all_cached then self:emit_completions(items_by_provider, on_items_by_provider) end + end) + :catch(function(err) vim.print('failed to get completions with error: ' .. err) end) +end + +function tree:emit_completions(items_by_provider, on_items_by_provider) + local nodes_falling_back = {} + local final_items_by_provider = {} + + local add_node_items + add_node_items = function(node) + for _, dependency in ipairs(node.dependencies) do + if not nodes_falling_back[dependency.id] then return end + end + local items = items_by_provider[node.id] + if items ~= nil and #items > 0 then + final_items_by_provider[node.id] = items + else + nodes_falling_back[node.id] = true + for _, dependent in ipairs(node.dependents) do + add_node_items(dependent) + end + end + end + + for _, node in ipairs(self:get_top_level_nodes()) do + add_node_items(node) + end + + on_items_by_provider(final_items_by_provider) +end + +--- Internal --- + +function tree:get_top_level_nodes() + local top_level_nodes = {} + for _, node in ipairs(self.nodes) do + if #node.dependencies == 0 then table.insert(top_level_nodes, node) end + end + return top_level_nodes +end + +--- Helper function to detect cycles using DFS +--- @param node blink.cmp.SourceTreeNode +--- @param visited? table +--- @param path? table +--- @return boolean +function tree.detect_cycle(node, visited, path) + visited = visited or {} + path = path or {} + + if path[node.id] then + -- Found a cycle - construct the cycle path for error message + local cycle = { node.id } + for id, _ in pairs(path) do + table.insert(cycle, id) + end + error('Circular dependency detected: ' .. table.concat(cycle, ' -> ')) + end + + if visited[node.id] then return false end + + visited[node.id] = true + path[node.id] = true + + for _, dependent in ipairs(node.dependents) do + if tree.detect_cycle(dependent, visited, path) then return true end + end + + path[node.id] = nil + return false +end + +return tree diff --git a/lua/blink/cmp/sources/lib/types.lua b/lua/blink/cmp/sources/lib/types.lua index 50ab2a32..e9900c03 100644 --- a/lua/blink/cmp/sources/lib/types.lua +++ b/lua/blink/cmp/sources/lib/types.lua @@ -5,12 +5,11 @@ --- @class blink.cmp.CompletionResponse --- @field is_incomplete_forward boolean --- @field is_incomplete_backward boolean ---- @field context blink.cmp.Context --- @field items blink.cmp.CompletionItem[] --- @class blink.cmp.Source --- @field new fun(opts: table, config: blink.cmp.SourceProviderConfig): blink.cmp.Source ---- @field enabled? fun(self: blink.cmp.Source, context: blink.cmp.Context): boolean +--- @field enabled? fun(self: blink.cmp.Source): boolean --- @field get_trigger_characters? fun(self: blink.cmp.Source): string[] --- @field get_completions? fun(self: blink.cmp.Source, context: blink.cmp.Context, callback: fun(response?: blink.cmp.CompletionResponse)): (fun(): nil) | nil --- @field should_show_completions? fun(self: blink.cmp.Source, context: blink.cmp.Context, response: blink.cmp.CompletionResponse): boolean diff --git a/lua/blink/cmp/sources/lib/utils.lua b/lua/blink/cmp/sources/lib/utils.lua index d67f9ced..efa5901f 100644 --- a/lua/blink/cmp/sources/lib/utils.lua +++ b/lua/blink/cmp/sources/lib/utils.lua @@ -1,53 +1,5 @@ local utils = {} ---- Checks if a request should be made, based on the previous response/context ---- and the new context ---- ---- @param new_context blink.cmp.Context ---- @param response blink.cmp.CompletionResponse ---- ---- @return false | 'forward' | 'backward' | 'unknown' -function utils.should_run_request(new_context, response) - local old_context = response.context - -- get the text for the current and queued context - local old_context_query = old_context.line:sub(old_context.bounds.start_col, old_context.cursor[2]) - local new_context_query = new_context.line:sub(new_context.bounds.start_col, new_context.cursor[2]) - - -- check if the texts are overlapping - local is_before = vim.startswith(old_context_query, new_context_query) - local is_after = vim.startswith(new_context_query, old_context_query) - - if is_before and response.is_incomplete_backward then return 'forward' end - if is_after and response.is_incomplete_forward then return 'backward' end - if (is_after == is_before) and (response.is_incomplete_backward or response.is_incomplete_forward) then - return 'unknown' - end - return false -end - ---- @param responses blink.cmp.CompletionResponse[] ---- @return blink.cmp.CompletionResponse -function utils.concat_responses(responses) - local is_cached = true - local is_incomplete_forward = false - local is_incomplete_backward = false - local items = {} - - for _, response in ipairs(responses) do - is_cached = is_cached and response.is_cached - is_incomplete_forward = is_incomplete_forward or response.is_incomplete_forward - is_incomplete_backward = is_incomplete_backward or response.is_incomplete_backward - vim.list_extend(items, response.items) - end - - return { - is_cached = is_cached, - is_incomplete_forward = is_incomplete_forward, - is_incomplete_backward = is_incomplete_backward, - items = items, - } -end - --- @param item blink.cmp.CompletionItem --- @return lsp.CompletionItem function utils.blink_item_to_lsp_item(item) diff --git a/lua/blink/cmp/sources/lsp.lua b/lua/blink/cmp/sources/lsp.lua index 25cdb621..0fcd4ae9 100644 --- a/lua/blink/cmp/sources/lsp.lua +++ b/lua/blink/cmp/sources/lsp.lua @@ -1,3 +1,12 @@ +local known_defaults = { + 'commitCharacters', + 'editRange', + 'insertTextFormat', + 'insertTextMode', + 'data', +} +local CompletionTriggerKind = vim.lsp.protocol.CompletionTriggerKind + --- @type blink.cmp.Source --- @diagnostic disable-next-line: missing-fields local lsp = {} @@ -40,6 +49,9 @@ function lsp:get_trigger_characters() return trigger_characters end +--- @param capability string +--- @param filter? table +--- @return vim.lsp.Client[] function lsp:get_clients_with_capability(capability, filter) local clients = {} for _, client in pairs(vim.lsp.get_clients(filter)) do @@ -50,121 +62,59 @@ function lsp:get_clients_with_capability(capability, filter) end function lsp:get_completions(context, callback) - -- TODO: should make separate LSP requests to return results earlier, in the case of slow LSPs + local clients = + self:get_clients_with_capability('completionProvider', { bufnr = 0, method = 'textDocument/completion' }) - -- no providers with completion support - if not self:has_capability('completionProvider') or not self:has_method('textDocument/completion') then + -- no clients with completion support + if #clients == 0 then callback({ is_incomplete_forward = false, is_incomplete_backward = false, items = {} }) return function() end end - -- TODO: offset encoding is global but should be per-client - local first_client = vim.lsp.get_clients({ bufnr = 0 })[1] - local offset_encoding = first_client and first_client.offset_encoding or 'utf-16' - - -- completion context with additional info about how it was triggered - local params = vim.lsp.util.make_position_params(nil, offset_encoding) - params.context = { - triggerKind = context.trigger.kind, - } - if context.trigger.kind == vim.lsp.protocol.CompletionTriggerKind.TriggerCharacter then - params.context.triggerCharacter = context.trigger.character - end + -- request from each client individually so slow LSPs don't delay the response + local cancel_fns = {} + for _, client in pairs(clients) do + local params = vim.lsp.util.make_position_params(0, client.offset_encoding) + params.context = { triggerKind = context.trigger.kind } + if context.trigger.kind == CompletionTriggerKind.TriggerCharacter then + params.context.triggerCharacter = context.trigger.character + end - -- special case, the first character of the context is a trigger character, so we adjust the position - -- sent to the LSP server to be the start of the trigger character - -- - -- some LSP do their own filtering before returning results, which we want to avoid - -- since we perform fuzzy matching ourselves. - -- - -- this also avoids having to make multiple calls to the LSP server in case characters are deleted - -- for these special cases - -- i.e. hello.wor| would be sent as hello.|wor - -- TODO: should we still make two calls to the LSP server and merge? - -- TODO: breaks the textEdit resolver since it assumes the request was made from the cursor - -- local trigger_characters = self:get_trigger_characters() - -- local trigger_character_block_list = { ' ', '\n', '\t' } - -- local bounds = context.bounds - -- local trigger_character_before_context = context.line:sub(bounds.start_col - 1, bounds.start_col - 1) - -- if - -- vim.tbl_contains(trigger_characters, trigger_character_before_context) - -- and not vim.tbl_contains(trigger_character_block_list, trigger_character_before_context) - -- then - -- local offset_encoding = vim.lsp.get_clients({ bufnr = 0 })[1].offset_encoding - -- params.position.character = - -- vim.lsp.util.character_offset(0, params.position.line, bounds.start_col - 1, offset_encoding) - -- end - - -- request from each of the clients - -- todo: refactor - return vim.lsp.buf_request_all(0, 'textDocument/completion', params, function(result) - local responses = {} - for client_id, response in pairs(result) do - -- todo: pass error upstream - if response.err or response.result == nil then - responses[client_id] = { is_incomplete_forward = true, is_incomplete_backward = true, items = {} } - - -- as per the spec, we assume it's complete if we get CompletionItem[] - elseif response.result.items == nil then - responses[client_id] = { - is_incomplete_forward = false, - is_incomplete_backward = true, - items = response.result, - } - - -- convert full response to our internal format - else - -- add defaults to the items - local defaults = response.result and response.result.itemDefaults or {} - local known_defaults = { - 'commitCharacters', - 'editRange', - 'insertTextFormat', - 'insertTextMode', - 'data', - } - for _, item in ipairs(response.result.items) do - for key, value in pairs(defaults) do - if vim.tbl_contains(known_defaults, key) then item[key] = item[key] or value end - end - end + local _, request_id = client.request('textDocument/completion', params, function(err, result) + if err or result == nil then return end - responses[client_id] = { - is_incomplete_forward = response.result.isIncomplete, - is_incomplete_backward = true, - items = response.result.items, - } - end - end + local items = result.items or result - -- add client_id and defaults to the items - for client_id, response in pairs(responses) do - for _, item in ipairs(response.items) do + -- add defaults, client id and score offset to the items + for _, item in ipairs(items) do -- todo: terraform lsp doesn't return a .kind in situations like `toset`, is there a default value we need to grab? - -- it doesn't seem to return itemDefaults either + -- it doesn't seem to return itemDefaults item.kind = item.kind or require('blink.cmp.types').CompletionItemKind.Text - item.client_id = client_id + + item.client_id = client.id -- todo: make configurable if item.deprecated or (item.tags and vim.tbl_contains(item.tags, 1)) then item.score_offset = -2 end + + for key, value in pairs(result.itemDefaults or {}) do + if vim.tbl_contains(known_defaults, key) then item[key] = item[key] or value end + end end - end - -- combine responses - -- todo: ideally pass multiple responses to the sources - -- so that we can do fine-grained isIncomplete - -- or do caching here - local combined_response = { is_incomplete_forward = false, is_incomplete_backward = false, items = {} } - for _, response in pairs(responses) do - combined_response.is_incomplete_forward = combined_response.is_incomplete_forward - or response.is_incomplete_forward - combined_response.is_incomplete_backward = combined_response.is_incomplete_backward - or response.is_incomplete_backward - vim.list_extend(combined_response.items, response.items) - end + callback({ + is_incomplete_forward = result.isIncomplete or false, + is_incomplete_backward = true, + items = items, + }) + end) + if request_id ~= nil then cancel_fns[#cancel_fns + 1] = function() client.cancel_request(request_id) end end + end - callback(combined_response) - end) + return function() + for _, cancel_fn in ipairs(cancel_fns) do + cancel_fn() + end + end end --- Resolve --- diff --git a/lua/blink/cmp/sources/snippets/utils.lua b/lua/blink/cmp/sources/snippets/utils.lua index dfae98b1..b323b593 100644 --- a/lua/blink/cmp/sources/snippets/utils.lua +++ b/lua/blink/cmp/sources/snippets/utils.lua @@ -1,4 +1,6 @@ -local utils = {} +local utils = { + parse_cache = {}, +} --- Parses the json file and notifies the user if there's an error ---@param path string @@ -26,8 +28,12 @@ end ---@type fun(input: string): vim.snippet.Node|nil function utils.safe_parse(input) + if utils.parse_cache[input] then return utils.parse_cache[input] end + local safe, parsed = pcall(vim.lsp._snippet_grammar.parse, input) if not safe then return nil end + + utils.parse_cache[input] = parsed return parsed end