Skip to content

refactor!: centralize sticky prompt and selection handling #855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -453,11 +453,6 @@ Below are all available configuration options with their default values:
headless = false, -- Do not write to chat buffer and use history(useful for using callback for custom processing)
callback = nil, -- Callback to use when ask response is received

-- default selection
selection = function(source)
return select.visual(source) or select.buffer(source)
end,

-- default window options
window = {
layout = 'vertical', -- 'vertical', 'horizontal', 'float', 'replace'
Expand Down Expand Up @@ -499,6 +494,12 @@ Below are all available configuration options with their default values:
error_header = '# Error ', -- Header to use for errors
separator = '───', -- Separator to use in chat

-- default selection
-- see config/select.lua for implementation
selection = function(source)
return select.visual(source) or select.buffer(source)
end,

-- default providers
-- see config/providers.lua for implementation
providers = {
Expand Down Expand Up @@ -535,10 +536,12 @@ Below are all available configuration options with their default values:
-- see config/prompts.lua for implementation
prompts = {
Explain = {
prompt = '> /COPILOT_EXPLAIN\n\nWrite an explanation for the selected code as paragraphs of text.',
prompt = 'Write an explanation for the selected code as paragraphs of text.',
sticky = '/COPILOT_EXPLAIN',
},
Review = {
prompt = '> /COPILOT_REVIEW\n\nReview the selected code.',
prompt = 'Review the selected code.',
sticky = '/COPILOT_REVIEW',
},
Fix = {
prompt = 'There is a problem in this code. Identify the issues and rewrite the code with fixes. Explain what was wrong and how your changes address the problems.',
Expand All @@ -553,7 +556,8 @@ Below are all available configuration options with their default values:
prompt = 'Please generate tests for my code.',
},
Commit = {
prompt = '> #git:staged\n\nWrite commit message for the change with commitizen convention. Keep the title under 50 characters and wrap message at 72 characters. Format as a gitcommit code block.',
prompt = 'Write commit message for the change with commitizen convention. Keep the title under 50 characters and wrap message at 72 characters. Format as a gitcommit code block.',
sticky = '#git:staged',
},
},

Expand Down
1 change: 0 additions & 1 deletion lua/CopilotChat/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
---@class CopilotChat.Client.agent : CopilotChat.Provider.agent
---@field provider string

local async = require('plenary.async')
local log = require('plenary.log')
local tiktoken = require('CopilotChat.tiktoken')
local notify = require('CopilotChat.notify')
Expand Down
12 changes: 6 additions & 6 deletions lua/CopilotChat/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ local select = require('CopilotChat.select')
---@field temperature number?
---@field headless boolean?
---@field callback fun(response: string, source: CopilotChat.source)?
---@field selection false|nil|fun(source: CopilotChat.source):CopilotChat.select.selection?
---@field window CopilotChat.config.window?
---@field show_help boolean?
---@field show_folds boolean?
Expand All @@ -47,6 +46,7 @@ local select = require('CopilotChat.select')
---@field answer_header string?
---@field error_header string?
---@field separator string?
---@field selection false|nil|fun(source: CopilotChat.source):CopilotChat.select.selection?
---@field providers table<string, CopilotChat.Provider>?
---@field contexts table<string, CopilotChat.config.context>?
---@field prompts table<string, CopilotChat.config.prompt|string>?
Expand All @@ -66,11 +66,6 @@ return {
headless = false, -- Do not write to chat buffer and use history(useful for using callback for custom processing)
callback = nil, -- Callback to use when ask response is received

-- default selection
selection = function(source)
return select.visual(source) or select.buffer(source)
end,

-- default window options
window = {
layout = 'vertical', -- 'vertical', 'horizontal', 'float', 'replace'
Expand Down Expand Up @@ -113,6 +108,11 @@ return {
error_header = '## Error ', -- Header to use for errors
separator = '───', -- Separator to use in chat

-- default selection
selection = function(source)
return select.visual(source) or select.buffer(source)
end,

-- default providers
providers = require('CopilotChat.config.providers'),

Expand Down
51 changes: 13 additions & 38 deletions lua/CopilotChat/config/mappings.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ local utils = require('CopilotChat.utils')
---@param chat CopilotChat.ui.Chat
---@return CopilotChat.ui.Diff.Diff?
local function get_diff(chat)
local config = chat.config
local block = chat:get_closest_block()

-- If no block found, return nil
Expand All @@ -15,7 +14,7 @@ local function get_diff(chat)

-- Initialize variables with selection if available
local header = block.header
local selection = copilot.get_selection(config)
local selection = copilot.get_selection()
local reference = selection and selection.content
local start_line = selection and selection.start_line
local end_line = selection and selection.end_line
Expand Down Expand Up @@ -64,35 +63,15 @@ local function get_diff(chat)
}
end

---@param winnr number
---@param bufnr number
---@param start_line number
---@param end_line number
---@param config CopilotChat.config.shared
local function jump_to_diff(winnr, bufnr, start_line, end_line, config)
pcall(vim.api.nvim_buf_set_mark, bufnr, '<', start_line, 0, {})
pcall(vim.api.nvim_buf_set_mark, bufnr, '>', end_line, 0, {})
pcall(vim.api.nvim_buf_set_mark, bufnr, '[', start_line, 0, {})
pcall(vim.api.nvim_buf_set_mark, bufnr, ']', end_line, 0, {})
pcall(vim.api.nvim_win_set_cursor, winnr, { start_line, 0 })
copilot.update_selection(config)
end

---@param diff CopilotChat.ui.Diff.Diff?
---@param config CopilotChat.config.shared
local function apply_diff(diff, config)
local function apply_diff(diff)
if not diff or not diff.bufnr then
return
end

local winnr = vim.fn.win_findbuf(diff.bufnr)[1]
if not winnr then
return
end

local lines = vim.split(diff.change, '\n', { trimempty = false })
vim.api.nvim_buf_set_lines(diff.bufnr, diff.start_line - 1, diff.end_line, false, lines)
jump_to_diff(winnr, diff.bufnr, diff.start_line, diff.start_line + #lines - 1, config)
copilot.set_selection(diff.bufnr, diff.start_line, diff.start_line + #lines - 1)
end

---@class CopilotChat.config.mapping
Expand Down Expand Up @@ -208,7 +187,7 @@ return {
normal = '<C-y>',
insert = '<C-y>',
callback = function(overlay, diff, chat, source)
apply_diff(get_diff(chat), chat.config)
apply_diff(get_diff(chat))
end,
},

Expand All @@ -234,8 +213,7 @@ return {

source.bufnr = diff_bufnr
vim.api.nvim_win_set_buf(source.winnr, diff_bufnr)

jump_to_diff(source.winnr, diff_bufnr, diff.start_line, diff.end_line, chat.config)
copilot.set_selection(diff_bufnr, diff.start_line, diff.end_line)
end,
},

Expand Down Expand Up @@ -278,7 +256,7 @@ return {
quickfix_diffs = {
normal = 'gqd',
callback = function(overlay, diff, chat)
local selection = copilot.get_selection(chat.config)
local selection = copilot.get_selection()
local items = {}

for _, section in ipairs(chat.sections) do
Expand Down Expand Up @@ -345,16 +323,16 @@ return {
end

local lines = {}
local prompt, config = copilot.resolve_prompts(section.content, chat.config)
local config, prompt = copilot.resolve_prompt(section.content)
local system_prompt = config.system_prompt

async.run(function()
local _, selected_agent = pcall(copilot.resolve_agent, prompt, config)
local _, selected_model = pcall(copilot.resolve_model, prompt, config)
local selected_agent = copilot.resolve_agent(prompt, config)
local selected_model = copilot.resolve_model(prompt, config)

utils.schedule_main()
table.insert(lines, '**Logs**: `' .. chat.config.log_path .. '`')
table.insert(lines, '**History**: `' .. chat.config.history_path .. '`')
table.insert(lines, '**Logs**: `' .. copilot.config.log_path .. '`')
table.insert(lines, '**History**: `' .. copilot.config.history_path .. '`')
table.insert(lines, '**Temp Files**: `' .. vim.fn.fnamemodify(os.tmpname(), ':h') .. '`')
table.insert(lines, '')

Expand Down Expand Up @@ -393,7 +371,7 @@ return {

local lines = {}

local selection = copilot.get_selection(chat.config)
local selection = copilot.get_selection()
if selection then
table.insert(lines, '**Selection**')
table.insert(lines, '```' .. selection.filetype)
Expand All @@ -405,10 +383,7 @@ return {
end

async.run(function()
local embeddings = {}
if section and not section.answer then
embeddings = copilot.resolve_embeddings(section.content, chat.config)
end
local embeddings = copilot.resolve_context(section.content)

for _, embedding in ipairs(embeddings) do
local embed_lines = vim.split(embedding.content, '\n')
Expand Down
9 changes: 6 additions & 3 deletions lua/CopilotChat/config/prompts.lua
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,13 @@ return {
},

Explain = {
prompt = '> /COPILOT_EXPLAIN\n\nWrite an explanation for the selected code as paragraphs of text.',
prompt = 'Write an explanation for the selected code as paragraphs of text.',
sticky = '/COPILOT_EXPLAIN',
},

Review = {
prompt = '> /COPILOT_REVIEW\n\nReview the selected code.',
prompt = 'Review the selected code.',
sticky = '/COPILOT_REVIEW',
callback = function(response, source)
local diagnostics = {}
for line in response:gmatch('[^\r\n]+') do
Expand Down Expand Up @@ -159,6 +161,7 @@ return {
},

Commit = {
prompt = '> #git:staged\n\nWrite commit message for the change with commitizen convention. Keep the title under 50 characters and wrap message at 72 characters. Format as a gitcommit code block.',
prompt = 'Write commit message for the change with commitizen convention. Keep the title under 50 characters and wrap message at 72 characters. Format as a gitcommit code block.',
sticky = '#git:staged',
},
}
Loading
Loading