Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use std::env;
use std::error::Error;
use std::path::PathBuf;

use env_logger::Builder as LoggerBuilder;

#[derive(Debug)]
pub enum CliAction {
Run { log_level: Option<String> },
Run {
log_level: Option<String>,
workdir: Option<PathBuf>,
restrict_to_workdir: bool,
},
Help,
Version,
}
Expand All @@ -15,6 +20,8 @@ where
I: Iterator<Item = String>,
{
let mut log_level = None;
let mut workdir = None;
let mut restrict_to_workdir = false;
let mut iter = args.peekable();

while let Some(arg) = iter.next() {
Expand All @@ -29,18 +36,34 @@ where
.next()
.ok_or_else(|| "--log-level requires a value".to_string())?;
log_level = Some(value);
} else if let Some(path) = arg.strip_prefix("--workdir=") {
if path.is_empty() {
return Err("--workdir requires a value".to_string());
}
workdir = Some(PathBuf::from(path));
} else if arg == "--workdir" {
let value = iter
.next()
.ok_or_else(|| "--workdir requires a value".to_string())?;
workdir = Some(PathBuf::from(value));
} else if arg == "--restrict-to-workdir" {
restrict_to_workdir = true;
} else {
return Err(format!("Unknown argument: {arg}"));
}
}
}
}

Ok(CliAction::Run { log_level })
Ok(CliAction::Run {
log_level,
workdir,
restrict_to_workdir,
})
}

pub fn print_usage() {
println!("Usage: codex-tools-mcp [OPTIONS]\n\nOptions:\n --log-level <level> Override default log level (info)\n -V, --version Print version information\n -h, --help Print this help message");
println!("Usage: codex-tools-mcp [OPTIONS]\n\nOptions:\n --log-level <level> Override default log level (info)\n --workdir <path> Set process working directory before serving\n --restrict-to-workdir Reject apply_patch paths that escape the working directory\n -V, --version Print version information\n -h, --help Print this help message");
}

pub fn init_logging(log_level: Option<String>) -> Result<(), Box<dyn Error>> {
Expand Down
37 changes: 35 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod tools;
use cli::{init_logging, parse_cli, print_usage, CliAction};
use std::env;
use std::error::Error;
use std::io;
use std::process;

fn main() {
Expand All @@ -26,9 +27,41 @@ fn try_main() -> Result<(), Box<dyn Error>> {
println!("{}", cli::version_string());
Ok(())
}
CliAction::Run { log_level } => {
CliAction::Run {
log_level,
workdir,
restrict_to_workdir,
} => {
if let Some(workdir) = workdir {
env::set_current_dir(&workdir).map_err(|err| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"Failed to set working directory to {}: {err}",
workdir.display()
),
)
})?;
}

let restrict_root = if restrict_to_workdir {
let cwd = env::current_dir().map_err(|err| {
io::Error::other(format!(
"Failed to resolve current working directory: {err}"
))
})?;
Some(cwd.canonicalize().map_err(|err| {
io::Error::other(format!(
"Failed to canonicalize working directory {}: {err}",
cwd.display()
))
})?)
} else {
None
};

init_logging(log_level)?;
server::run_server()?;
server::run_server(server::ServerConfig { restrict_root })?;
Ok(())
}
}
Expand Down
162 changes: 156 additions & 6 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
use codex_apply_patch::apply_patch as run_apply_patch;
use log::{debug, error, info, warn};
use serde_json::{json, Value};
use std::fs;
use std::io::{self, BufRead, Write};
use std::path::{Component, Path, PathBuf};

use crate::tools::{
apply_patch_tool_schema, update_plan_tool_schema, INVALID_PARAMS, INVALID_REQUEST,
JSONRPC_VERSION, MCP_PROTOCOL_VERSION, METHOD_NOT_FOUND, PARSE_ERROR,
};

pub fn run_server() -> io::Result<()> {
pub struct ServerConfig {
pub restrict_root: Option<PathBuf>,
}

pub fn run_server(config: ServerConfig) -> io::Result<()> {
let stdin = io::stdin();
for line_result in stdin.lock().lines() {
let line = match line_result {
Expand Down Expand Up @@ -46,15 +52,15 @@ pub fn run_server() -> io::Result<()> {
continue;
}

if let Err(err) = handle_message(message) {
if let Err(err) = handle_message(message, &config) {
error!("internal error while processing message: {err}");
}
}

Ok(())
}

fn handle_message(message: Value) -> io::Result<()> {
fn handle_message(message: Value, config: &ServerConfig) -> io::Result<()> {
let method = message
.get("method")
.and_then(Value::as_str)
Expand All @@ -67,7 +73,7 @@ fn handle_message(message: Value) -> io::Result<()> {
match method {
"initialize" => handle_initialize(request_id, params),
"tools/list" => handle_tools_list(request_id),
"tools/call" => handle_tools_call(request_id, params),
"tools/call" => handle_tools_call(request_id, params, config),
"ping" => handle_ping(request_id),
_ => send_error(
request_id,
Expand Down Expand Up @@ -122,7 +128,11 @@ fn handle_tools_list(request_id: Option<Value>) -> io::Result<()> {
send_result(request_id, result)
}

fn handle_tools_call(request_id: Option<Value>, params: Option<Value>) -> io::Result<()> {
fn handle_tools_call(
request_id: Option<Value>,
params: Option<Value>,
config: &ServerConfig,
) -> io::Result<()> {
if request_id.is_none() {
return send_error(None, INVALID_REQUEST, "tools/call must include an id");
}
Expand Down Expand Up @@ -151,7 +161,7 @@ fn handle_tools_call(request_id: Option<Value>, params: Option<Value>) -> io::Re
});
send_result(request_id, result)
}
Some("apply_patch") => handle_apply_patch_tool(request_id, &params_obj),
Some("apply_patch") => handle_apply_patch_tool(request_id, &params_obj, config),
Some(other) => {
warn!("unknown tool requested: {other}");
send_error(
Expand All @@ -167,6 +177,7 @@ fn handle_tools_call(request_id: Option<Value>, params: Option<Value>) -> io::Re
fn handle_apply_patch_tool(
request_id: Option<Value>,
params_obj: &serde_json::Map<String, Value>,
config: &ServerConfig,
) -> io::Result<()> {
let arguments = match params_obj.get("arguments") {
Some(Value::Object(arguments)) => arguments,
Expand All @@ -191,6 +202,13 @@ fn handle_apply_patch_tool(
}
};

if let Some(root) = &config.restrict_root {
if let Err(err) = validate_patch_paths_within_root(&patch, root) {
warn!("apply_patch blocked by --restrict-to-workdir: {err}");
return send_apply_patch_error(request_id, err);
}
}

info!("running apply_patch ({} bytes)", patch.len());

let mut stdout_buf = Vec::new();
Expand Down Expand Up @@ -256,6 +274,138 @@ fn handle_apply_patch_tool(
}
}

fn validate_patch_paths_within_root(patch: &str, root: &Path) -> Result<(), String> {
const FILE_PREFIXES: [(&str, PatchPathKind); 4] = [
("*** Add File: ", PatchPathKind::WriteLike),
("*** Update File: ", PatchPathKind::WriteLike),
("*** Delete File: ", PatchPathKind::Delete),
("*** Move to: ", PatchPathKind::WriteLike),
];

for (idx, line) in patch.lines().enumerate() {
for (prefix, kind) in FILE_PREFIXES {
if let Some(path_text) = line.strip_prefix(prefix) {
validate_patch_path(path_text, root, kind)
.map_err(|err| format!("{err} (line {})", idx + 1))?;
}
}
}

Ok(())
}

#[derive(Clone, Copy)]
enum PatchPathKind {
WriteLike,
Delete,
}

fn validate_patch_path(path_text: &str, root: &Path, kind: PatchPathKind) -> Result<(), String> {
if path_text.is_empty() {
return Err("Patch path cannot be empty".to_string());
}

let path = Path::new(path_text);
if path.is_absolute() {
return Err(format!("Absolute patch paths are not allowed: {path_text}"));
}

let normalized = normalize_relative_path(path)
.ok_or_else(|| format!("Patch path escapes workdir: {path_text}"))?;
if normalized.as_os_str().is_empty() {
return Err(format!("Patch path cannot resolve to workdir root: {path_text}"));
}

let allow_terminal_symlink_delete = matches!(kind, PatchPathKind::Delete);
ensure_path_stays_within_root(root, &normalized, allow_terminal_symlink_delete)
.map_err(|err| format!("{err}: {path_text}"))?;

Ok(())
}

fn normalize_relative_path(path: &Path) -> Option<PathBuf> {
let mut normalized = PathBuf::new();
for component in path.components() {
match component {
Component::CurDir => {}
Component::Normal(seg) => normalized.push(seg),
Component::ParentDir => {
if !normalized.pop() {
return None;
}
}
_ => return None,
}
}
Some(normalized)
}

fn ensure_path_stays_within_root(
root: &Path,
relative: &Path,
allow_terminal_symlink_delete: bool,
) -> Result<(), String> {
let mut cursor = root.to_path_buf();
let mut components = relative.components().peekable();
while let Some(component) = components.next() {
let is_last = components.peek().is_none();
let segment = match component {
Component::Normal(seg) => seg,
_ => return Err("Patch path contains unsupported component".to_string()),
};

cursor.push(segment);

match fs::symlink_metadata(&cursor) {
Ok(metadata) => match fs::canonicalize(&cursor) {
Ok(canonical) => {
if allow_terminal_symlink_delete && is_last && metadata.file_type().is_symlink()
{
// Deleting the symlink entry itself is safe, regardless of where it points.
continue;
}
if !canonical.starts_with(root) {
return Err("Patch path escapes workdir via symlink".to_string());
}
cursor = canonical;
}
Err(err)
if allow_terminal_symlink_delete
&& is_last
&& err.kind() == io::ErrorKind::NotFound
&& metadata.file_type().is_symlink() =>
{
// Allow deleting a broken symlink that is inside workdir.
}
Err(err) => return Err(format!("Failed to resolve patch path: {err}")),
},
Err(err) if err.kind() == io::ErrorKind::NotFound => {
// Remaining components do not exist yet, so lexical relative joining is enough.
}
Err(err) => return Err(format!("Failed to inspect patch path: {err}")),
}
}

if !cursor.starts_with(root) {
return Err("Patch path escapes workdir".to_string());
}

Ok(())
}

fn send_apply_patch_error(request_id: Option<Value>, message: impl Into<String>) -> io::Result<()> {
let result = json!({
"content": [
{
"type": "text",
"text": format!("apply_patch failed: {}", message.into()),
}
],
"isError": true,
});
send_result(request_id, result)
}

fn handle_ping(request_id: Option<Value>) -> io::Result<()> {
if request_id.is_none() {
return send_error(None, INVALID_REQUEST, "ping must include an id");
Expand Down
Loading