Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type (
Logging LoggingConfig `yaml:"logging"`
LogConfigs map[string]LoggingConfig `yaml:"logConfigs"`
ClusterConnections []ClusterConnConfig `yaml:"clusterConnections"`
Encryption Encryption `yaml:"encryption"`
}

SATranslationConfig struct {
Expand Down Expand Up @@ -165,7 +166,7 @@ func WriteConfig[T any](config T, filePath string) error {
}

// Write the YAML to a file
err = os.WriteFile(filePath, data, 0644)
err = os.WriteFile(filePath, data, 0o644)
if err != nil {
return err
}
Expand All @@ -179,9 +180,7 @@ type (
}
)

var (
EmptyConfigProvider MockConfigProvider
)
var EmptyConfigProvider MockConfigProvider

func NewMockConfigProvider(config S2SProxyConfig) *MockConfigProvider {
return &MockConfigProvider{config: config}
Expand Down Expand Up @@ -245,6 +244,7 @@ func (s SearchAttributeTranslation) Inverse() SearchAttributeTranslation {
inverted: !s.inverted,
}
}

func (s SearchAttributeTranslation) withInvert(namespace string) (collect.StaticBiMap[string, string], bool) {
m, found := s.inner[namespace]
if !found {
Expand All @@ -255,27 +255,32 @@ func (s SearchAttributeTranslation) withInvert(namespace string) (collect.Static
}
return m, true
}

func (s SearchAttributeTranslation) Get(namespace string, searchAttr string) string {
if m, ok := s.withInvert(namespace); ok {
return m.Get(searchAttr)
}
return ""
}

func (s SearchAttributeTranslation) GetExists(namespace string, searchAttr string) (string, bool) {
if m, ok := s.withInvert(namespace); ok {
return m.GetExists(searchAttr)
}
return "", false
}

func (s SearchAttributeTranslation) LenNamespaces() int {
return len(s.inner)
}

func (s SearchAttributeTranslation) Len(namespace string) int {
if m, ok := s.withInvert(namespace); ok {
return m.Len()
}
return 0
}

func (s SearchAttributeTranslation) FlattenMaps() map[string]map[string]string {
raw := make(map[string]map[string]string, len(s.inner))
for ns, mappings := range s.inner {
Expand Down
94 changes: 94 additions & 0 deletions config/encryption.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package config

import (
"fmt"
"net/url"
"slices"
"strings"
"time"
)

// The set of valid KMS key schemes
var validKeySchemes = []string{
"awskms",
"azurekeyvault",
"gcpkms",
"testing",
}

type (
Encryption struct {
// Enabled determines whether encryption is enabled. Decryption will be attempted for encrypted
// payloads regardless of this flag.
Enabled bool `yaml:"enabled"`

// Policy defines keys and rotation policies.
Policy KeyPolicy `yaml:"policy"`
}

KeyPolicy struct {
// URI is the vendor-specific URL identifying the KMS key used to encrypt DEKs.
// URIs follow the gocloud.dev/secrets URL scheme (with the exception of the testing scheme):
//
// GCP KMS: gcpkms://projects/PROJECT/locations/LOCATION/keyRings/RING/cryptoKeys/KEY
// AWS KMS: awskms:///arn:aws:kms:REGION:ACCOUNT:key/KEY-ID?region=REGION
// Azure Vault: azurekeyvault://VAULT.vault.azure.net/keys/KEY-NAME/KEY-VERSION
// Local/test: testing://smGbjm71Nxd1Ig5FS0wj9SlbzAIrnolCz9bQQ6uAhl4=
URI string `yaml:"uri"`

// RetiredURIs lists KMS keys that are no longer used to encrypt new DEKs but
// are still needed to decrypt DEKs encrypted by previous keys (e.g. after a
// provider migration). Each entry follows the same URI scheme rules as [URI].
RetiredURIs []string `yaml:"retiredURIs"`

// Duration is how long the DEK is valid before it must be rotated.
Duration time.Duration `yaml:"duration"`

// RenewBefore is how far before a DEK expires it should be proactively rotated.
RenewBefore time.Duration `yaml:"renewBefore"`
}
)

func (p *KeyPolicy) UnmarshalYAML(unmarshal func(any) error) error {
type raw KeyPolicy
var decoded raw
if err := unmarshal(&decoded); err != nil {
return err
}

*p = KeyPolicy(decoded)
return p.validURIs()
}

func (p *KeyPolicy) validURIs() error {
if err := validKeyURI(p.URI); err != nil {
if strings.TrimSpace(p.URI) != "" {
return err
}
}

for _, uri := range p.RetiredURIs {
if err := validKeyURI(uri); err != nil {
return err
}
}

return nil
}

func validKeyURI(uri string) error {
u, err := url.Parse(uri)
if err != nil {
return fmt.Errorf("failed to parse key URI: %s, %w", uri, err)
}

if !slices.Contains(validKeySchemes, strings.ToLower(u.Scheme)) {
return fmt.Errorf(
"invalid key URI: %s, valid schemes: [%s]",
uri,
strings.Join(validKeySchemes, ","),
)
}

return nil
}
59 changes: 59 additions & 0 deletions config/encryption_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package config_test

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"

"github.com/temporalio/s2s-proxy/config"
)

func TestKeyPolicyURIs(t *testing.T) {
tests := []struct {
name string
uri string
wantErr bool
}{
{name: "empty"},
{name: "gcpkms", uri: "gcpkms://projects/p/locations/global/keyRings/r/cryptoKeys/k"},
{name: "awskms", uri: "awskms:///arn:aws:kms:us-east-1:123456789012:key/abc?region=us-east-1"},
{name: "azurekeyvault", uri: "azurekeyvault://my-vault.vault.azure.net/keys/my-key/v1"},
{name: "testing", uri: "testing://smGbjm71Nxd1Ig5FS0wj9SlbzAIrnolCz9bQQ6uAhl4="},
{name: "unknown scheme", uri: "hashivault://localhost/v1/transit/keys/my-key", wantErr: true},
{name: "unparseable URI", uri: "://bad", wantErr: true},
}

for _, tt := range tests {
str := "policy:\n uri: " + tt.uri

var enc config.Encryption
err := yaml.Unmarshal([]byte(str), &enc)
if tt.wantErr {
require.Error(t, err, tt.name)
continue
}

require.NoError(t, err, tt.name)
}

// Verify retired URIs are also validated
for _, tt := range tests {
tt.name += " retired"
if tt.uri == "" {
tt.wantErr = true
}

str := fmt.Sprintf("policy:\n retiredURIs:\n - \"%s\"", tt.uri)

var enc config.Encryption
err := yaml.Unmarshal([]byte(str), &enc)
if tt.wantErr {
require.Error(t, err, tt.name)
continue
}

require.NoError(t, err, tt.name)
}
}
Loading