diff --git a/lua/sg/cody/rpc.lua b/lua/sg/cody/rpc.lua index a0adc571..d7fd9a58 100644 --- a/lua/sg/cody/rpc.lua +++ b/lua/sg/cody/rpc.lua @@ -40,6 +40,12 @@ end --- Gets the server config ---@return CodyClientInfo local get_server_config = function(creds, remote_url) + -- Add any custom headers from user configuration + local custom_headers = { ["User-Agent"] = "Sourcegraph Cody Neovim Plugin" } + if config.src_headers then + custom_headers = vim.tbl_extend("error", custom_headers, config.src_headers) + end + return { name = "neovim", version = require("sg.private.data").version, @@ -48,7 +54,7 @@ local get_server_config = function(creds, remote_url) accessToken = creds.token, serverEndpoint = creds.endpoint, codebase = remote_url, - customHeaders = { ["User-Agent"] = "Sourcegraph Cody Neovim Plugin" }, + customHeaders = custom_headers, eventProperties = { anonymousUserID = require("sg.private.data").get_cody_data().user, prefix = "CodyNeovimPlugin", diff --git a/lua/sg/config.lua b/lua/sg/config.lua index 4df0d9e5..b65e2b39 100644 --- a/lua/sg/config.lua +++ b/lua/sg/config.lua @@ -24,6 +24,7 @@ ---@field skip_node_check boolean?: Useful if using other js runtime ---@field cody_agent string?: path to the cody-agent js bundle ---@field on_attach function?: function to run when attaching to sg:// buffers +---@field src_headers? table: Headers to be sent with each sg request ---@type sg.config local config = { diff --git a/lua/sg/lsp/init.lua b/lua/sg/lsp/init.lua index a8296557..c2d8058e 100644 --- a/lua/sg/lsp/init.lua +++ b/lua/sg/lsp/init.lua @@ -43,13 +43,20 @@ M.get_client_id = function() return end + local src_headers = require("sg.config").src_headers + if not src_headers or src_headers == "" then + src_headers = nil + end + ---@diagnostic disable-next-line: missing-fields + local headers = require("sg.config").src_headers M._client = vim.lsp.start_client { name = "sourcegraph", cmd = { cmd }, cmd_env = { SRC_ENDPOINT = auth.endpoint, SRC_ACCESS_TOKEN = auth.token, + SRC_HEADERS = src_headers and vim.json.encode(src_headers) or nil, }, handlers = { -- For definitions, we need to preload the buffers so that we don't diff --git a/lua/sg/request.lua b/lua/sg/request.lua index 9425cca4..b9b0dd83 100644 --- a/lua/sg/request.lua +++ b/lua/sg/request.lua @@ -43,6 +43,8 @@ M.start = function(opts) vim.wait(10) end + local src_headers = require("sg.config").src_headers + -- Verify that the environment is properly configured M.client = lsp.start(bin_sg_nvim, {}, { notification = function(method, data) @@ -66,6 +68,7 @@ M.start = function(opts) PATH = vim.env.PATH, SRC_ACCESS_TOKEN = vim.env.SRC_ACCESS_TOKEN, SRC_ENDPOINT = vim.env.SRC_ENDPOINT, + SRC_HEADERS = src_headers and vim.json.encode(src_headers) or nil, }, }) diff --git a/src/bin/sg-lsp.rs b/src/bin/sg-lsp.rs index a62c8e3c..8bec0742 100644 --- a/src/bin/sg-lsp.rs +++ b/src/bin/sg-lsp.rs @@ -4,8 +4,8 @@ use { lsp_server::{Connection, ExtractError, Message, Request, RequestId, Response}, lsp_types::{ request::{GotoDefinition, HoverRequest, References}, - GotoDefinitionParams, GotoDefinitionResponse, Hover, HoverParams, InitializeParams, - ReferenceParams, ServerCapabilities, + GotoDefinitionParams, GotoDefinitionResponse, Hover, HoverParams, ReferenceParams, + ServerCapabilities, }, serde::{Deserialize, Serialize}, }; @@ -172,9 +172,14 @@ async fn handle_sourcegraph_read( Ok(()) } +#[derive(Debug, Default, Deserialize, Serialize)] +struct SgInitOptions { + headers: Option, +} + async fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { - let _params: InitializeParams = serde_json::from_value(params)?; - info!("Starting main loop..."); + // let src_headers = serde_json::valu + info!("Starting main loop: {:?}", params); for msg in &connection.receiver { info!("got msg: {:?}", msg); diff --git a/src/lib.rs b/src/lib.rs index b44d1e6b..d4fa0d3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ use { anyhow::Result, graphql_client::GraphQLQuery, lsp_types::Location, once_cell::sync::Lazy, regex::Regex, reqwest::Client, sg_gql::dotcom_user::UserInfo, sg_types::*, + std::collections::HashMap, }; pub mod auth; @@ -51,16 +52,38 @@ mod graphql { pub fn get_headers() -> reqwest::header::HeaderMap { use reqwest::header::*; - let mut x = HeaderMap::new(); + let mut header_map = HeaderMap::new(); + + // Add same user agent as we do for Cody requests via node + header_map.insert( + USER_AGENT, + HeaderValue::from_static("Sourcegraph Cody Neovim Plugin"), + ); + + // Auth if let Some(sourcegraph_access_token) = auth::get_access_token() { - x.insert( + header_map.insert( AUTHORIZATION, HeaderValue::from_str(&format!("token {sourcegraph_access_token}")) .expect("to make header"), ); } - x + // If `SRC_HEADERS` is set, append these headers to the request. + if let Ok(src_headers) = std::env::var("SRC_HEADERS") { + if let Ok(headers) = serde_json::from_str::>(&src_headers) { + for (key, value) in headers { + let header = HeaderName::from_bytes(key.as_bytes()); + let value = HeaderValue::from_str(&value); + + if let (Ok(header), Ok(value)) = (header, value) { + header_map.insert(header, value); + }; + } + } + } + + header_map } macro_rules! wrap_request {