Skip to content

Commit 33c350e

Browse files
committed
Simplify how source buffer is being saved and unify prompt clearing
- Instead of passing state.source around, simply update it only if we are outside of chat window and make M.open() noop otherwise - Unify when last prompt is cleared, clear it in .ask always instead of doing it only when explicitely submitting prompt - Unify how selection is checked when simply showing it Signed-off-by: Tomas Slusny <slusnucky@gmail.com>
1 parent c3518e5 commit 33c350e

File tree

1 file changed

+61
-36
lines changed

1 file changed

+61
-36
lines changed

lua/CopilotChat/init.lua

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262

6363
local function find_lines_between_separator(
6464
lines,
65-
start_line,
65+
current_line,
6666
start_pattern,
6767
end_pattern,
6868
allow_end_of_file
@@ -72,7 +72,6 @@ local function find_lines_between_separator(
7272
end
7373

7474
local line_count = #lines
75-
local current_line = vim.api.nvim_win_get_cursor(0)[1] - start_line
7675
local separator_line_start = 1
7776
local separator_line_finish = line_count
7877
local found_one = false
@@ -385,14 +384,20 @@ end
385384

386385
--- Open the chat window.
387386
---@param config CopilotChat.config|CopilotChat.config.prompt|nil
388-
---@param source CopilotChat.config.source?
389-
function M.open(config, source)
387+
function M.open(config)
388+
-- If we are already in chat window, do nothing
389+
if state.chat:active() then
390+
return
391+
end
392+
390393
config = vim.tbl_deep_extend('force', M.config, config or {})
391394
state.config = config
392-
state.source = vim.tbl_extend('keep', source or {}, {
395+
396+
-- Save the source buffer and window (e.g the buffer we are currently asking about)
397+
state.source = {
393398
bufnr = vim.api.nvim_get_current_buf(),
394399
winnr = vim.api.nvim_get_current_win(),
395-
})
400+
}
396401

397402
utils.return_to_normal_mode()
398403
state.chat:open(config)
@@ -407,12 +412,11 @@ end
407412

408413
--- Toggle the chat window.
409414
---@param config CopilotChat.config|nil
410-
---@param source CopilotChat.config.source?
411-
function M.toggle(config, source)
415+
function M.toggle(config)
412416
if state.chat:visible() then
413417
M.close()
414418
else
415-
M.open(config, source)
419+
M.open(config)
416420
end
417421
end
418422

@@ -472,11 +476,10 @@ end
472476
--- Ask a question to the Copilot model.
473477
---@param prompt string
474478
---@param config CopilotChat.config|CopilotChat.config.prompt|nil
475-
---@param source CopilotChat.config.source?
476-
function M.ask(prompt, config, source)
479+
function M.ask(prompt, config)
477480
config = vim.tbl_deep_extend('force', M.config, config or {})
478481
vim.diagnostic.reset(vim.api.nvim_create_namespace('copilot_diagnostics'))
479-
M.open(config, source)
482+
M.open(config)
480483

481484
prompt = vim.trim(prompt or '')
482485
if prompt == '' then
@@ -489,6 +492,14 @@ function M.ask(prompt, config, source)
489492
finish(config, nil, true)
490493
end
491494

495+
-- Clear the current input prompt before asking a new question
496+
local chat_lines = vim.api.nvim_buf_get_lines(state.chat.bufnr, 0, -1, false)
497+
local _, start_line, end_line =
498+
find_lines_between_separator(chat_lines, #chat_lines, M.config.separator .. '$', nil, true)
499+
if #chat_lines == end_line then
500+
vim.api.nvim_buf_set_lines(state.chat.bufnr, start_line, end_line, false, { '' })
501+
end
502+
492503
state.chat:append(prompt)
493504
state.chat:append('\n\n' .. config.answer_header .. config.separator .. '\n\n')
494505

@@ -884,17 +895,15 @@ function M.setup(config)
884895

885896
map_key(M.config.mappings.submit_prompt, bufnr, function()
886897
local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
887-
local lines, start_line, end_line =
888-
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$', nil, true)
889-
local input = vim.trim(table.concat(lines, '\n'))
890-
if input ~= '' then
891-
-- If we are entering the input at the end, replace it
892-
if #chat_lines == end_line then
893-
vim.api.nvim_buf_set_lines(bufnr, start_line, end_line, false, { '' })
894-
end
895-
896-
M.ask(input, state.config, state.source)
897-
end
898+
local current_line = vim.api.nvim_win_get_cursor(0)[1]
899+
local lines = find_lines_between_separator(
900+
chat_lines,
901+
current_line,
902+
M.config.separator .. '$',
903+
nil,
904+
true
905+
)
906+
M.ask(vim.trim(table.concat(lines, '\n')), state.config)
898907
end)
899908

900909
map_key(M.config.mappings.toggle_sticky, bufnr, function()
@@ -909,7 +918,7 @@ function M.setup(config)
909918

910919
local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
911920
local _, start_line, end_line =
912-
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$', nil, true)
921+
find_lines_between_separator(chat_lines, cur_line, M.config.separator .. '$', nil, true)
913922

914923
if vim.startswith(current_line, '> ') then
915924
return
@@ -942,15 +951,20 @@ function M.setup(config)
942951

943952
map_key(M.config.mappings.accept_diff, bufnr, function()
944953
local selection = get_selection()
945-
if not selection.start_row or not selection.end_row then
954+
if not selection or not selection.start_row or not selection.end_row then
946955
return
947956
end
948957

949958
local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
959+
local current_line = vim.api.nvim_win_get_cursor(0)[1]
950960
local section_lines, start_line =
951-
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$')
952-
local lines =
953-
find_lines_between_separator(section_lines, start_line - 1, '^```%w+$', '^```$')
961+
find_lines_between_separator(chat_lines, current_line, M.config.separator .. '$')
962+
local lines = find_lines_between_separator(
963+
section_lines,
964+
current_line - start_line - 1,
965+
'^```%w+$',
966+
'^```$'
967+
)
954968
if #lines > 0 then
955969
vim.api.nvim_buf_set_text(
956970
state.source.bufnr,
@@ -965,15 +979,20 @@ function M.setup(config)
965979

966980
map_key(M.config.mappings.yank_diff, bufnr, function()
967981
local selection = get_selection()
968-
if not selection.start_row or not selection.end_row then
982+
if not selection or not selection.lines then
969983
return
970984
end
971985

972986
local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
987+
local current_line = vim.api.nvim_win_get_cursor(0)[1]
973988
local section_lines, start_line =
974-
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$')
975-
local lines =
976-
find_lines_between_separator(section_lines, start_line - 1, '^```%w+$', '^```$')
989+
find_lines_between_separator(chat_lines, current_line, M.config.separator .. '$')
990+
local lines = find_lines_between_separator(
991+
section_lines,
992+
current_line - start_line - 1,
993+
'^```%w+$',
994+
'^```$'
995+
)
977996
if #lines > 0 then
978997
local content = table.concat(lines, '\n')
979998
vim.fn.setreg(M.config.mappings.yank_diff.register, content)
@@ -982,15 +1001,21 @@ function M.setup(config)
9821001

9831002
map_key(M.config.mappings.show_diff, bufnr, function()
9841003
local selection = get_selection()
985-
if not selection or not selection.start_row or not selection.end_row then
1004+
if not selection or not selection.lines then
9861005
return
9871006
end
9881007

9891008
local chat_lines = vim.api.nvim_buf_get_lines(state.chat.bufnr, 0, -1, false)
1009+
local current_line = vim.api.nvim_win_get_cursor(0)[1]
9901010
local section_lines, start_line =
991-
find_lines_between_separator(chat_lines, 0, M.config.separator .. '$')
1011+
find_lines_between_separator(chat_lines, current_line, M.config.separator .. '$')
9921012
local lines = table.concat(
993-
find_lines_between_separator(section_lines, start_line - 1, '^```%w+$', '^```$'),
1013+
find_lines_between_separator(
1014+
section_lines,
1015+
current_line - start_line - 1,
1016+
'^```%w+$',
1017+
'^```$'
1018+
),
9941019
'\n'
9951020
)
9961021
if vim.trim(lines) ~= '' then
@@ -1026,7 +1051,7 @@ function M.setup(config)
10261051

10271052
map_key(M.config.mappings.show_user_selection, bufnr, function()
10281053
local selection = get_selection()
1029-
if not selection.start_row or not selection.end_row then
1054+
if not selection or not selection.lines then
10301055
return
10311056
end
10321057

0 commit comments

Comments
 (0)