Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 174 additions & 1 deletion crates/openshell-policy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::path::Path;

use miette::{IntoDiagnostic, Result, WrapErr};
use openshell_core::proto::{
FilesystemPolicy, L7Allow, L7Rule, LandlockPolicy, NetworkBinary, NetworkEndpoint,
CorsConfig, FilesystemPolicy, L7Allow, L7Rule, LandlockPolicy, NetworkBinary, NetworkEndpoint,
NetworkPolicyRule, ProcessPolicy, SandboxPolicy,
};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -99,6 +99,19 @@ struct NetworkEndpointDef {
rules: Vec<L7RuleDef>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
allowed_ips: Vec<String>,
/// CORS configuration for port-forwarded services on this endpoint.
#[serde(default, skip_serializing_if = "Option::is_none")]
cors: Option<CorsConfigDef>,
}

/// CORS configuration for port-forwarded sandbox services.
#[derive(Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
struct CorsConfigDef {
/// Allowed origin URLs (e.g. "https://app.example.com").
/// Use "*" to allow all origins.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
allowed_origins: Vec<String>,
}

fn is_zero(v: &u32) -> bool {
Expand Down Expand Up @@ -180,6 +193,9 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy {
})
.collect(),
allowed_ips: e.allowed_ips,
cors: e.cors.map(|c| CorsConfig {
allowed_origins: c.allowed_origins,
}),
}
})
.collect(),
Expand Down Expand Up @@ -280,6 +296,9 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile {
})
.collect(),
allowed_ips: e.allowed_ips.clone(),
cors: e.cors.as_ref().map(|c| CorsConfigDef {
allowed_origins: c.allowed_origins.clone(),
}),
}
})
.collect(),
Expand Down Expand Up @@ -443,6 +462,8 @@ pub enum PolicyViolation {
FieldTooLong { path: String, length: usize },
/// Too many filesystem paths in the policy.
TooManyPaths { count: usize },
/// CORS configuration has no `allowed_origins`.
CorsEmptyOrigins { endpoint: String },
}

impl fmt::Display for PolicyViolation {
Expand Down Expand Up @@ -472,6 +493,12 @@ impl fmt::Display for PolicyViolation {
"too many filesystem paths ({count} > {MAX_FILESYSTEM_PATHS})"
)
}
Self::CorsEmptyOrigins { endpoint } => {
write!(
f,
"CORS config on endpoint '{endpoint}' has no allowed_origins"
)
}
}
}
}
Expand Down Expand Up @@ -558,6 +585,24 @@ pub fn validate_sandbox_policy(
}
}

// Check CORS configurations in network endpoints
for (key, rule) in &policy.network_policies {
for ep in &rule.endpoints {
if let Some(ref cors) = ep.cors {
let ep_label = if ep.host.is_empty() {
format!("{key}:port-{}", ep.port)
} else {
format!("{key}:{}", ep.host)
};
if cors.allowed_origins.is_empty() {
violations.push(PolicyViolation::CorsEmptyOrigins {
endpoint: ep_label.clone(),
});
}
}
}
}

if violations.is_empty() {
Ok(())
} else {
Expand Down Expand Up @@ -1117,4 +1162,132 @@ network_policies:
proto2.network_policies["test"].endpoints[0].host
);
}

// ---- CORS configuration tests ----

#[test]
fn parse_cors_config() {
let yaml = r#"
version: 1
network_policies:
web:
name: web
endpoints:
- host: localhost
port: 8080
cors:
allowed_origins:
- "https://app.example.com"
- "https://dashboard.example.com"
binaries:
- path: /usr/bin/node
"#;
let policy = parse_sandbox_policy(yaml).expect("should parse");
let ep = &policy.network_policies["web"].endpoints[0];
let cors = ep.cors.as_ref().expect("cors should be present");
assert_eq!(
cors.allowed_origins,
vec!["https://app.example.com", "https://dashboard.example.com"]
);
}

#[test]
fn round_trip_preserves_cors_config() {
let yaml = r#"
version: 1
network_policies:
web:
name: web
endpoints:
- host: localhost
port: 8080
cors:
allowed_origins:
- "https://app.example.com"
binaries:
- path: /usr/bin/node
"#;
let proto1 = parse_sandbox_policy(yaml).expect("parse failed");
let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed");
let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed");

let cors1 = proto1.network_policies["web"].endpoints[0]
.cors
.as_ref()
.expect("cors");
let cors2 = proto2.network_policies["web"].endpoints[0]
.cors
.as_ref()
.expect("cors");
assert_eq!(cors1.allowed_origins, cors2.allowed_origins);
}

#[test]
fn parse_endpoint_without_cors() {
let yaml = r#"
version: 1
network_policies:
test:
name: test
endpoints:
- host: api.example.com
port: 443
binaries:
- path: /usr/bin/curl
"#;
let policy = parse_sandbox_policy(yaml).expect("should parse");
assert!(
policy.network_policies["test"].endpoints[0].cors.is_none(),
"cors should be None when not specified"
);
}

#[test]
fn validate_cors_empty_origins() {
let mut policy = restrictive_default_policy();
policy.network_policies.insert(
"web".to_string(),
NetworkPolicyRule {
name: "web".to_string(),
endpoints: vec![NetworkEndpoint {
host: "localhost".to_string(),
port: 8080,
ports: vec![8080],
cors: Some(CorsConfig {
allowed_origins: vec![],
}),
..Default::default()
}],
binaries: vec![],
},
);
let violations = validate_sandbox_policy(&policy).unwrap_err();
assert!(
violations
.iter()
.any(|v| matches!(v, PolicyViolation::CorsEmptyOrigins { .. }))
);
}

#[test]
fn validate_cors_valid_config() {
let mut policy = restrictive_default_policy();
policy.network_policies.insert(
"web".to_string(),
NetworkPolicyRule {
name: "web".to_string(),
endpoints: vec![NetworkEndpoint {
host: "localhost".to_string(),
port: 8080,
ports: vec![8080],
cors: Some(CorsConfig {
allowed_origins: vec!["https://app.example.com".to_string()],
}),
..Default::default()
}],
binaries: vec![],
},
);
assert!(validate_sandbox_policy(&policy).is_ok());
}
}
Loading
Loading