diff --git a/github.com/apache/thrift/NOTICE b/github.com/apache/thrift/NOTICE index 902dc8d314..37824e7fb6 100644 --- a/github.com/apache/thrift/NOTICE +++ b/github.com/apache/thrift/NOTICE @@ -1,5 +1,5 @@ Apache Thrift -Copyright 2006-2017 The Apache Software Foundation. +Copyright (C) 2006 - 2019, The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). diff --git a/github.com/apache/thrift/lib/go/thrift/application_exception.go b/github.com/apache/thrift/lib/go/thrift/application_exception.go index b9d7eedcdd..0023c57cf1 100644 --- a/github.com/apache/thrift/lib/go/thrift/application_exception.go +++ b/github.com/apache/thrift/lib/go/thrift/application_exception.go @@ -28,6 +28,9 @@ const ( MISSING_RESULT = 5 INTERNAL_ERROR = 6 PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 ) var defaultApplicationExceptionMessage = map[int32]string{ @@ -39,6 +42,9 @@ var defaultApplicationExceptionMessage = map[int32]string{ MISSING_RESULT: "missing result", INTERNAL_ERROR: "unknown internal error", PROTOCOL_ERROR: "unknown protocol error", + INVALID_TRANSFORM: "Invalid transform", + INVALID_PROTOCOL: "Invalid protocol", + UNSUPPORTED_CLIENT_TYPE: "Unsupported client type", } // Application level Thrift exception diff --git a/github.com/apache/thrift/lib/go/thrift/binary_protocol.go b/github.com/apache/thrift/lib/go/thrift/binary_protocol.go index 1f90bf4351..93ae898cf5 100644 --- a/github.com/apache/thrift/lib/go/thrift/binary_protocol.go +++ b/github.com/apache/thrift/lib/go/thrift/binary_protocol.go @@ -32,8 +32,6 @@ import ( type TBinaryProtocol struct { trans TRichTransport origTransport TTransport - reader io.Reader - writer io.Writer strictRead bool strictWrite bool buffer [64]byte @@ -55,8 +53,6 @@ func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProt } else { p.trans = NewTRichTransport(t) } - p.reader = p.trans - p.writer = p.trans return p } @@ -192,21 +188,21 @@ func (p *TBinaryProtocol) WriteByte(value int8) error { func (p *TBinaryProtocol) WriteI16(value int16) error { v := p.buffer[0:2] binary.BigEndian.PutUint16(v, uint16(value)) - _, e := p.writer.Write(v) + _, e := p.trans.Write(v) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI32(value int32) error { v := p.buffer[0:4] binary.BigEndian.PutUint32(v, uint32(value)) - _, e := p.writer.Write(v) + _, e := p.trans.Write(v) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI64(value int64) error { v := p.buffer[0:8] binary.BigEndian.PutUint64(v, uint64(value)) - _, err := p.writer.Write(v) + _, err := p.trans.Write(v) return NewTProtocolException(err) } @@ -228,7 +224,7 @@ func (p *TBinaryProtocol) WriteBinary(value []byte) error { if e != nil { return e } - _, err := p.writer.Write(value) + _, err := p.trans.Write(value) return NewTProtocolException(err) } @@ -468,7 +464,7 @@ func (p *TBinaryProtocol) Transport() TTransport { } func (p *TBinaryProtocol) readAll(buf []byte) error { - _, err := io.ReadFull(p.reader, buf) + _, err := io.ReadFull(p.trans, buf) return NewTProtocolException(err) } diff --git a/github.com/apache/thrift/lib/go/thrift/client.go b/github.com/apache/thrift/lib/go/thrift/client.go index 28791ccd0c..b073a952d9 100644 --- a/github.com/apache/thrift/lib/go/thrift/client.go +++ b/github.com/apache/thrift/lib/go/thrift/client.go @@ -24,6 +24,16 @@ func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClien } func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error { + // Set headers from context object on THeaderProtocol + if headerProt, ok := oprot.(*THeaderProtocol); ok { + headerProt.ClearWriteHeaders() + for _, key := range GetWriteHeaderList(ctx) { + if value, ok := GetHeader(ctx, key); ok { + headerProt.SetWriteHeader(key, value) + } + } + } + if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil { return err } diff --git a/github.com/apache/thrift/lib/go/thrift/framed_transport.go b/github.com/apache/thrift/lib/go/thrift/framed_transport.go index 81fa65aaae..34275b5f4d 100644 --- a/github.com/apache/thrift/lib/go/thrift/framed_transport.go +++ b/github.com/apache/thrift/lib/go/thrift/framed_transport.go @@ -93,7 +93,21 @@ func (p *TFramedTransport) Read(buf []byte) (l int, err error) { l, err = p.Read(tmp) copy(buf, tmp) if err == nil { - err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf))) + // Note: It's important to only return an error when l + // is zero. + // In io.Reader.Read interface, it's perfectly fine to + // return partial data and nil error, which means + // "This is all the data we have right now without + // blocking. If you need the full data, call Read again + // or use io.ReadFull instead". + // Returning partial data with an error actually means + // there's no more data after the partial data just + // returned, which is not true in this case + // (it might be that the other end just haven't written + // them yet). + if l == 0 { + err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf))) + } return } } diff --git a/github.com/apache/thrift/lib/go/thrift/header_context.go b/github.com/apache/thrift/lib/go/thrift/header_context.go new file mode 100644 index 0000000000..21e880d66c --- /dev/null +++ b/github.com/apache/thrift/lib/go/thrift/header_context.go @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" +) + +// See https://godoc.org/context#WithValue on why do we need the unexported typedefs. +type ( + headerKey string + headerKeyList int +) + +// Values for headerKeyList. +const ( + headerKeyListRead headerKeyList = iota + headerKeyListWrite +) + +// SetHeader sets a header in the context. +func SetHeader(ctx context.Context, key, value string) context.Context { + return context.WithValue( + ctx, + headerKey(key), + value, + ) +} + +// GetHeader returns a value of the given header from the context. +func GetHeader(ctx context.Context, key string) (value string, ok bool) { + if v := ctx.Value(headerKey(key)); v != nil { + value, ok = v.(string) + } + return +} + +// SetReadHeaderList sets the key list of read THeaders in the context. +func SetReadHeaderList(ctx context.Context, keys []string) context.Context { + return context.WithValue( + ctx, + headerKeyListRead, + keys, + ) +} + +// GetReadHeaderList returns the key list of read THeaders from the context. +func GetReadHeaderList(ctx context.Context) []string { + if v := ctx.Value(headerKeyListRead); v != nil { + if value, ok := v.([]string); ok { + return value + } + } + return nil +} + +// SetWriteHeaderList sets the key list of THeaders to write in the context. +func SetWriteHeaderList(ctx context.Context, keys []string) context.Context { + return context.WithValue( + ctx, + headerKeyListWrite, + keys, + ) +} + +// GetWriteHeaderList returns the key list of THeaders to write from the context. +func GetWriteHeaderList(ctx context.Context) []string { + if v := ctx.Value(headerKeyListWrite); v != nil { + if value, ok := v.([]string); ok { + return value + } + } + return nil +} + +// AddReadTHeaderToContext adds the whole THeader headers into context. +func AddReadTHeaderToContext(ctx context.Context, headers THeaderMap) context.Context { + keys := make([]string, 0, len(headers)) + for key, value := range headers { + ctx = SetHeader(ctx, key, value) + keys = append(keys, key) + } + return SetReadHeaderList(ctx, keys) +} diff --git a/github.com/apache/thrift/lib/go/thrift/header_protocol.go b/github.com/apache/thrift/lib/go/thrift/header_protocol.go new file mode 100644 index 0000000000..46205b28ba --- /dev/null +++ b/github.com/apache/thrift/lib/go/thrift/header_protocol.go @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" +) + +// THeaderProtocol is a thrift protocol that implements THeader: +// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md +// +// It supports either binary or compact protocol as the wrapped protocol. +// +// Most of the THeader handlings are happening inside THeaderTransport. +type THeaderProtocol struct { + transport *THeaderTransport + + // Will be initialized on first read/write. + protocol TProtocol +} + +// NewTHeaderProtocol creates a new THeaderProtocol from the underlying +// transport. The passed in transport will be wrapped with THeaderTransport. +// +// Note that THeaderTransport handles frame and zlib by itself, +// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket), +// instead of rich transports like TZlibTransport or TFramedTransport. +func NewTHeaderProtocol(trans TTransport) *THeaderProtocol { + t := NewTHeaderTransport(trans) + p, _ := THeaderProtocolDefault.GetProtocol(t) + return &THeaderProtocol{ + transport: t, + protocol: p, + } +} + +type tHeaderProtocolFactory struct{} + +func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return NewTHeaderProtocol(trans) +} + +// NewTHeaderProtocolFactory creates a factory for THeader. +// +// It's a wrapper for NewTHeaderProtocol +func NewTHeaderProtocolFactory() TProtocolFactory { + return tHeaderProtocolFactory{} +} + +// Transport returns the underlying transport. +// +// It's guaranteed to be of type *THeaderTransport. +func (p *THeaderProtocol) Transport() TTransport { + return p.transport +} + +// GetReadHeaders returns the THeaderMap read from transport. +func (p *THeaderProtocol) GetReadHeaders() THeaderMap { + return p.transport.GetReadHeaders() +} + +// SetWriteHeader sets a header for write. +func (p *THeaderProtocol) SetWriteHeader(key, value string) { + p.transport.SetWriteHeader(key, value) +} + +// ClearWriteHeaders clears all write headers previously set. +func (p *THeaderProtocol) ClearWriteHeaders() { + p.transport.ClearWriteHeaders() +} + +// AddTransform add a transform for writing. +func (p *THeaderProtocol) AddTransform(transform THeaderTransformID) error { + return p.transport.AddTransform(transform) +} + +func (p *THeaderProtocol) Flush(ctx context.Context) error { + return p.transport.Flush(ctx) +} + +func (p *THeaderProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error { + newProto, err := p.transport.Protocol().GetProtocol(p.transport) + if err != nil { + return err + } + p.protocol = newProto + p.transport.SequenceID = seqID + return p.protocol.WriteMessageBegin(name, typeID, seqID) +} + +func (p *THeaderProtocol) WriteMessageEnd() error { + if err := p.protocol.WriteMessageEnd(); err != nil { + return err + } + return p.transport.Flush(context.Background()) +} + +func (p *THeaderProtocol) WriteStructBegin(name string) error { + return p.protocol.WriteStructBegin(name) +} + +func (p *THeaderProtocol) WriteStructEnd() error { + return p.protocol.WriteStructEnd() +} + +func (p *THeaderProtocol) WriteFieldBegin(name string, typeID TType, id int16) error { + return p.protocol.WriteFieldBegin(name, typeID, id) +} + +func (p *THeaderProtocol) WriteFieldEnd() error { + return p.protocol.WriteFieldEnd() +} + +func (p *THeaderProtocol) WriteFieldStop() error { + return p.protocol.WriteFieldStop() +} + +func (p *THeaderProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + return p.protocol.WriteMapBegin(keyType, valueType, size) +} + +func (p *THeaderProtocol) WriteMapEnd() error { + return p.protocol.WriteMapEnd() +} + +func (p *THeaderProtocol) WriteListBegin(elemType TType, size int) error { + return p.protocol.WriteListBegin(elemType, size) +} + +func (p *THeaderProtocol) WriteListEnd() error { + return p.protocol.WriteListEnd() +} + +func (p *THeaderProtocol) WriteSetBegin(elemType TType, size int) error { + return p.protocol.WriteSetBegin(elemType, size) +} + +func (p *THeaderProtocol) WriteSetEnd() error { + return p.protocol.WriteSetEnd() +} + +func (p *THeaderProtocol) WriteBool(value bool) error { + return p.protocol.WriteBool(value) +} + +func (p *THeaderProtocol) WriteByte(value int8) error { + return p.protocol.WriteByte(value) +} + +func (p *THeaderProtocol) WriteI16(value int16) error { + return p.protocol.WriteI16(value) +} + +func (p *THeaderProtocol) WriteI32(value int32) error { + return p.protocol.WriteI32(value) +} + +func (p *THeaderProtocol) WriteI64(value int64) error { + return p.protocol.WriteI64(value) +} + +func (p *THeaderProtocol) WriteDouble(value float64) error { + return p.protocol.WriteDouble(value) +} + +func (p *THeaderProtocol) WriteString(value string) error { + return p.protocol.WriteString(value) +} + +func (p *THeaderProtocol) WriteBinary(value []byte) error { + return p.protocol.WriteBinary(value) +} + +// ReadFrame calls underlying THeaderTransport's ReadFrame function. +func (p *THeaderProtocol) ReadFrame() error { + return p.transport.ReadFrame() +} + +func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) { + if err = p.transport.ReadFrame(); err != nil { + return + } + + var newProto TProtocol + newProto, err = p.transport.Protocol().GetProtocol(p.transport) + if err != nil { + tAppExc, ok := err.(TApplicationException) + if !ok { + return + } + if e := p.protocol.WriteMessageBegin("", EXCEPTION, seqID); e != nil { + return + } + if e := tAppExc.Write(p.protocol); e != nil { + return + } + if e := p.protocol.WriteMessageEnd(); e != nil { + return + } + if e := p.transport.Flush(context.Background()); e != nil { + return + } + return + } + p.protocol = newProto + + return p.protocol.ReadMessageBegin() +} + +func (p *THeaderProtocol) ReadMessageEnd() error { + return p.protocol.ReadMessageEnd() +} + +func (p *THeaderProtocol) ReadStructBegin() (name string, err error) { + return p.protocol.ReadStructBegin() +} + +func (p *THeaderProtocol) ReadStructEnd() error { + return p.protocol.ReadStructEnd() +} + +func (p *THeaderProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) { + return p.protocol.ReadFieldBegin() +} + +func (p *THeaderProtocol) ReadFieldEnd() error { + return p.protocol.ReadFieldEnd() +} + +func (p *THeaderProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { + return p.protocol.ReadMapBegin() +} + +func (p *THeaderProtocol) ReadMapEnd() error { + return p.protocol.ReadMapEnd() +} + +func (p *THeaderProtocol) ReadListBegin() (elemType TType, size int, err error) { + return p.protocol.ReadListBegin() +} + +func (p *THeaderProtocol) ReadListEnd() error { + return p.protocol.ReadListEnd() +} + +func (p *THeaderProtocol) ReadSetBegin() (elemType TType, size int, err error) { + return p.protocol.ReadSetBegin() +} + +func (p *THeaderProtocol) ReadSetEnd() error { + return p.protocol.ReadSetEnd() +} + +func (p *THeaderProtocol) ReadBool() (value bool, err error) { + return p.protocol.ReadBool() +} + +func (p *THeaderProtocol) ReadByte() (value int8, err error) { + return p.protocol.ReadByte() +} + +func (p *THeaderProtocol) ReadI16() (value int16, err error) { + return p.protocol.ReadI16() +} + +func (p *THeaderProtocol) ReadI32() (value int32, err error) { + return p.protocol.ReadI32() +} + +func (p *THeaderProtocol) ReadI64() (value int64, err error) { + return p.protocol.ReadI64() +} + +func (p *THeaderProtocol) ReadDouble() (value float64, err error) { + return p.protocol.ReadDouble() +} + +func (p *THeaderProtocol) ReadString() (value string, err error) { + return p.protocol.ReadString() +} + +func (p *THeaderProtocol) ReadBinary() (value []byte, err error) { + return p.protocol.ReadBinary() +} + +func (p *THeaderProtocol) Skip(fieldType TType) error { + return p.protocol.Skip(fieldType) +} diff --git a/github.com/apache/thrift/lib/go/thrift/header_transport.go b/github.com/apache/thrift/lib/go/thrift/header_transport.go new file mode 100644 index 0000000000..5343ccb46b --- /dev/null +++ b/github.com/apache/thrift/lib/go/thrift/header_transport.go @@ -0,0 +1,723 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" + "bytes" + "compress/zlib" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" +) + +// Size in bytes for 32-bit ints. +const size32 = 4 + +type headerMeta struct { + MagicFlags uint32 + SequenceID int32 + HeaderLength uint16 +} + +const headerMetaSize = 10 + +type clientType int + +const ( + clientUnknown clientType = iota + clientHeaders + clientFramedBinary + clientUnframedBinary + clientFramedCompact + clientUnframedCompact +) + +// Constants defined in THeader format: +// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md +const ( + THeaderHeaderMagic uint32 = 0x0fff0000 + THeaderHeaderMask uint32 = 0xffff0000 + THeaderFlagsMask uint32 = 0x0000ffff + THeaderMaxFrameSize uint32 = 0x3fffffff +) + +// THeaderMap is the type of the header map in THeader transport. +type THeaderMap map[string]string + +// THeaderProtocolID is the wrapped protocol id used in THeader. +type THeaderProtocolID int32 + +// Supported THeaderProtocolID values. +const ( + THeaderProtocolBinary THeaderProtocolID = 0x00 + THeaderProtocolCompact THeaderProtocolID = 0x02 + THeaderProtocolDefault = THeaderProtocolBinary +) + +// GetProtocol gets the corresponding TProtocol from the wrapped protocol id. +func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) { + switch id { + default: + return nil, NewTApplicationException( + INVALID_PROTOCOL, + fmt.Sprintf("THeader protocol id %d not supported", id), + ) + case THeaderProtocolBinary: + return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), nil + case THeaderProtocolCompact: + return NewTCompactProtocol(trans), nil + } +} + +// THeaderTransformID defines the numeric id of the transform used. +type THeaderTransformID int32 + +// THeaderTransformID values +const ( + TransformNone THeaderTransformID = iota // 0, no special handling + TransformZlib // 1, zlib + // Rest of the values are not currently supported, namely HMAC and Snappy. +) + +var supportedTransformIDs = map[THeaderTransformID]bool{ + TransformNone: true, + TransformZlib: true, +} + +// TransformReader is an io.ReadCloser that handles transforms reading. +type TransformReader struct { + io.Reader + + closers []io.Closer +} + +var _ io.ReadCloser = (*TransformReader)(nil) + +// NewTransformReaderWithCapacity initializes a TransformReader with expected +// closers capacity. +// +// If you don't know the closers capacity beforehand, just use +// +// &TransformReader{Reader: baseReader} +// +// instead would be sufficient. +func NewTransformReaderWithCapacity(baseReader io.Reader, capacity int) *TransformReader { + return &TransformReader{ + Reader: baseReader, + closers: make([]io.Closer, 0, capacity), + } +} + +// Close calls the underlying closers in appropriate order, +// stops at and returns the first error encountered. +func (tr *TransformReader) Close() error { + // Call closers in reversed order + for i := len(tr.closers) - 1; i >= 0; i-- { + if err := tr.closers[i].Close(); err != nil { + return err + } + } + return nil +} + +// AddTransform adds a transform. +func (tr *TransformReader) AddTransform(id THeaderTransformID) error { + switch id { + default: + return NewTApplicationException( + INVALID_TRANSFORM, + fmt.Sprintf("THeaderTransformID %d not supported", id), + ) + case TransformNone: + // no-op + case TransformZlib: + readCloser, err := zlib.NewReader(tr.Reader) + if err != nil { + return err + } + tr.Reader = readCloser + tr.closers = append(tr.closers, readCloser) + } + return nil +} + +// TransformWriter is an io.WriteCloser that handles transforms writing. +type TransformWriter struct { + io.Writer + + closers []io.Closer +} + +var _ io.WriteCloser = (*TransformWriter)(nil) + +// NewTransformWriter creates a new TransformWriter with base writer and transforms. +func NewTransformWriter(baseWriter io.Writer, transforms []THeaderTransformID) (io.WriteCloser, error) { + writer := &TransformWriter{ + Writer: baseWriter, + closers: make([]io.Closer, 0, len(transforms)), + } + for _, id := range transforms { + if err := writer.AddTransform(id); err != nil { + return nil, err + } + } + return writer, nil +} + +// Close calls the underlying closers in appropriate order, +// stops at and returns the first error encountered. +func (tw *TransformWriter) Close() error { + // Call closers in reversed order + for i := len(tw.closers) - 1; i >= 0; i-- { + if err := tw.closers[i].Close(); err != nil { + return err + } + } + return nil +} + +// AddTransform adds a transform. +func (tw *TransformWriter) AddTransform(id THeaderTransformID) error { + switch id { + default: + return NewTApplicationException( + INVALID_TRANSFORM, + fmt.Sprintf("THeaderTransformID %d not supported", id), + ) + case TransformNone: + // no-op + case TransformZlib: + writeCloser := zlib.NewWriter(tw.Writer) + tw.Writer = writeCloser + tw.closers = append(tw.closers, writeCloser) + } + return nil +} + +// THeaderInfoType is the type id of the info headers. +type THeaderInfoType int32 + +// Supported THeaderInfoType values. +const ( + _ THeaderInfoType = iota // Skip 0 + InfoKeyValue // 1 + // Rest of the info types are not supported. +) + +// THeaderTransport is a Transport mode that implements THeader. +// +// Note that THeaderTransport handles frame and zlib by itself, +// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket), +// instead of rich transports like TZlibTransport or TFramedTransport. +type THeaderTransport struct { + SequenceID int32 + Flags uint32 + + transport TTransport + + // THeaderMap for read and write + readHeaders THeaderMap + writeHeaders THeaderMap + + // Reading related variables. + reader *bufio.Reader + // When frame is detected, we read the frame fully into frameBuffer. + frameBuffer bytes.Buffer + // When it's non-nil, Read should read from frameReader instead of + // reader, and EOF error indicates end of frame instead of end of all + // transport. + frameReader io.ReadCloser + + // Writing related variables + writeBuffer bytes.Buffer + writeTransforms []THeaderTransformID + + clientType clientType + protocolID THeaderProtocolID + + // buffer is used in the following scenarios to avoid repetitive + // allocations, while 4 is big enough for all those scenarios: + // + // * header padding (max size 4) + // * write the frame size (size 4) + buffer [4]byte +} + +var _ TTransport = (*THeaderTransport)(nil) + +// NewTHeaderTransport creates THeaderTransport from the underlying transport. +// +// Please note that THeaderTransport handles framing and zlib by itself, +// so the underlying transport should be the raw socket transports (TSocket or TSSLSocket), +// instead of rich transports like TZlibTransport or TFramedTransport. +// +// If trans is already a *THeaderTransport, it will be returned as is. +func NewTHeaderTransport(trans TTransport) *THeaderTransport { + if ht, ok := trans.(*THeaderTransport); ok { + return ht + } + return &THeaderTransport{ + transport: trans, + reader: bufio.NewReader(trans), + writeHeaders: make(THeaderMap), + protocolID: THeaderProtocolDefault, + } +} + +// Open calls the underlying transport's Open function. +func (t *THeaderTransport) Open() error { + return t.transport.Open() +} + +// IsOpen calls the underlying transport's IsOpen function. +func (t *THeaderTransport) IsOpen() bool { + return t.transport.IsOpen() +} + +// ReadFrame tries to read the frame header, guess the client type, and handle +// unframed clients. +func (t *THeaderTransport) ReadFrame() error { + if !t.needReadFrame() { + // No need to read frame, skipping. + return nil + } + // Peek and handle the first 32 bits. + // They could either be the length field of a framed message, + // or the first bytes of an unframed message. + buf, err := t.reader.Peek(size32) + if err != nil { + return err + } + frameSize := binary.BigEndian.Uint32(buf) + if frameSize&VERSION_MASK == VERSION_1 { + t.clientType = clientUnframedBinary + return nil + } + if buf[0] == COMPACT_PROTOCOL_ID && buf[1]&COMPACT_VERSION_MASK == COMPACT_VERSION { + t.clientType = clientUnframedCompact + return nil + } + + // At this point it should be a framed message, + // sanity check on frameSize then discard the peeked part. + if frameSize > THeaderMaxFrameSize { + return NewTProtocolExceptionWithType( + SIZE_LIMIT, + errors.New("frame too large"), + ) + } + t.reader.Discard(size32) + + // Read the frame fully into frameBuffer. + _, err = io.Copy( + &t.frameBuffer, + io.LimitReader(t.reader, int64(frameSize)), + ) + if err != nil { + return err + } + t.frameReader = ioutil.NopCloser(&t.frameBuffer) + + // Peek and handle the next 32 bits. + buf = t.frameBuffer.Bytes()[:size32] + version := binary.BigEndian.Uint32(buf) + if version&THeaderHeaderMask == THeaderHeaderMagic { + t.clientType = clientHeaders + return t.parseHeaders(frameSize) + } + if version&VERSION_MASK == VERSION_1 { + t.clientType = clientFramedBinary + return nil + } + if buf[0] == COMPACT_PROTOCOL_ID && buf[1]&COMPACT_VERSION_MASK == COMPACT_VERSION { + t.clientType = clientFramedCompact + return nil + } + if err := t.endOfFrame(); err != nil { + return err + } + return NewTProtocolExceptionWithType( + NOT_IMPLEMENTED, + errors.New("unsupported client transport type"), + ) +} + +// endOfFrame does end of frame handling. +// +// It closes frameReader, and also resets frame related states. +func (t *THeaderTransport) endOfFrame() error { + defer func() { + t.frameBuffer.Reset() + t.frameReader = nil + }() + return t.frameReader.Close() +} + +func (t *THeaderTransport) parseHeaders(frameSize uint32) error { + if t.clientType != clientHeaders { + return nil + } + + var err error + var meta headerMeta + if err = binary.Read(&t.frameBuffer, binary.BigEndian, &meta); err != nil { + return err + } + frameSize -= headerMetaSize + t.Flags = meta.MagicFlags & THeaderFlagsMask + t.SequenceID = meta.SequenceID + headerLength := int64(meta.HeaderLength) * 4 + if int64(frameSize) < headerLength { + return NewTProtocolExceptionWithType( + SIZE_LIMIT, + errors.New("header size is larger than the whole frame"), + ) + } + headerBuf := NewTMemoryBuffer() + _, err = io.Copy(headerBuf, io.LimitReader(&t.frameBuffer, headerLength)) + if err != nil { + return err + } + hp := NewTCompactProtocol(headerBuf) + + // At this point the header is already read into headerBuf, + // and t.frameBuffer starts from the actual payload. + protoID, err := hp.readVarint32() + if err != nil { + return err + } + t.protocolID = THeaderProtocolID(protoID) + var transformCount int32 + transformCount, err = hp.readVarint32() + if err != nil { + return err + } + if transformCount > 0 { + reader := NewTransformReaderWithCapacity( + &t.frameBuffer, + int(transformCount), + ) + t.frameReader = reader + transformIDs := make([]THeaderTransformID, transformCount) + for i := 0; i < int(transformCount); i++ { + id, err := hp.readVarint32() + if err != nil { + return err + } + transformIDs[i] = THeaderTransformID(id) + } + // The transform IDs on the wire was added based on the order of + // writing, so on the reading side we need to reverse the order. + for i := transformCount - 1; i >= 0; i-- { + id := transformIDs[i] + if err := reader.AddTransform(id); err != nil { + return err + } + } + } + + // The info part does not use the transforms yet, so it's + // important to continue using headerBuf. + headers := make(THeaderMap) + for { + infoType, err := hp.readVarint32() + if err == io.EOF { + break + } + if err != nil { + return err + } + if THeaderInfoType(infoType) == InfoKeyValue { + count, err := hp.readVarint32() + if err != nil { + return err + } + for i := 0; i < int(count); i++ { + key, err := hp.ReadString() + if err != nil { + return err + } + value, err := hp.ReadString() + if err != nil { + return err + } + headers[key] = value + } + } else { + // Skip reading info section on the first + // unsupported info type. + break + } + } + t.readHeaders = headers + + return nil +} + +func (t *THeaderTransport) needReadFrame() bool { + if t.clientType == clientUnknown { + // This is a new connection that's never read before. + return true + } + if t.isFramed() && t.frameReader == nil { + // We just finished the last frame. + return true + } + return false +} + +func (t *THeaderTransport) Read(p []byte) (read int, err error) { + err = t.ReadFrame() + if err != nil { + return + } + if t.frameReader != nil { + read, err = t.frameReader.Read(p) + if err == io.EOF { + err = t.endOfFrame() + if err != nil { + return + } + if read < len(p) { + var nextRead int + nextRead, err = t.Read(p[read:]) + read += nextRead + } + } + return + } + return t.reader.Read(p) +} + +// Write writes data to the write buffer. +// +// You need to call Flush to actually write them to the transport. +func (t *THeaderTransport) Write(p []byte) (int, error) { + return t.writeBuffer.Write(p) +} + +// Flush writes the appropriate header and the write buffer to the underlying transport. +func (t *THeaderTransport) Flush(ctx context.Context) error { + if t.writeBuffer.Len() == 0 { + return nil + } + + defer t.writeBuffer.Reset() + + switch t.clientType { + default: + fallthrough + case clientUnknown: + t.clientType = clientHeaders + fallthrough + case clientHeaders: + headers := NewTMemoryBuffer() + hp := NewTCompactProtocol(headers) + if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil { + return NewTTransportExceptionFromError(err) + } + if _, err := hp.writeVarint32(int32(len(t.writeTransforms))); err != nil { + return NewTTransportExceptionFromError(err) + } + for _, transform := range t.writeTransforms { + if _, err := hp.writeVarint32(int32(transform)); err != nil { + return NewTTransportExceptionFromError(err) + } + } + if len(t.writeHeaders) > 0 { + if _, err := hp.writeVarint32(int32(InfoKeyValue)); err != nil { + return NewTTransportExceptionFromError(err) + } + if _, err := hp.writeVarint32(int32(len(t.writeHeaders))); err != nil { + return NewTTransportExceptionFromError(err) + } + for key, value := range t.writeHeaders { + if err := hp.WriteString(key); err != nil { + return NewTTransportExceptionFromError(err) + } + if err := hp.WriteString(value); err != nil { + return NewTTransportExceptionFromError(err) + } + } + } + padding := 4 - headers.Len()%4 + if padding < 4 { + buf := t.buffer[:padding] + for i := range buf { + buf[i] = 0 + } + if _, err := headers.Write(buf); err != nil { + return NewTTransportExceptionFromError(err) + } + } + + var payload bytes.Buffer + meta := headerMeta{ + MagicFlags: THeaderHeaderMagic + t.Flags&THeaderFlagsMask, + SequenceID: t.SequenceID, + HeaderLength: uint16(headers.Len() / 4), + } + if err := binary.Write(&payload, binary.BigEndian, meta); err != nil { + return NewTTransportExceptionFromError(err) + } + if _, err := io.Copy(&payload, headers); err != nil { + return NewTTransportExceptionFromError(err) + } + + writer, err := NewTransformWriter(&payload, t.writeTransforms) + if err != nil { + return NewTTransportExceptionFromError(err) + } + if _, err := io.Copy(writer, &t.writeBuffer); err != nil { + return NewTTransportExceptionFromError(err) + } + if err := writer.Close(); err != nil { + return NewTTransportExceptionFromError(err) + } + + // First write frame length + buf := t.buffer[:size32] + binary.BigEndian.PutUint32(buf, uint32(payload.Len())) + if _, err := t.transport.Write(buf); err != nil { + return NewTTransportExceptionFromError(err) + } + // Then write the payload + if _, err := io.Copy(t.transport, &payload); err != nil { + return NewTTransportExceptionFromError(err) + } + + case clientFramedBinary, clientFramedCompact: + buf := t.buffer[:size32] + binary.BigEndian.PutUint32(buf, uint32(t.writeBuffer.Len())) + if _, err := t.transport.Write(buf); err != nil { + return NewTTransportExceptionFromError(err) + } + fallthrough + case clientUnframedBinary, clientUnframedCompact: + if _, err := io.Copy(t.transport, &t.writeBuffer); err != nil { + return NewTTransportExceptionFromError(err) + } + } + + select { + default: + case <-ctx.Done(): + return NewTTransportExceptionFromError(ctx.Err()) + } + + return t.transport.Flush(ctx) +} + +// Close closes the transport, along with its underlying transport. +func (t *THeaderTransport) Close() error { + if err := t.Flush(context.Background()); err != nil { + return err + } + return t.transport.Close() +} + +// RemainingBytes calls underlying transport's RemainingBytes. +// +// Even in framed cases, because of all the possible compression transforms +// involved, the remaining frame size is likely to be different from the actual +// remaining readable bytes, so we don't bother to keep tracking the remaining +// frame size by ourselves and just use the underlying transport's +// RemainingBytes directly. +func (t *THeaderTransport) RemainingBytes() uint64 { + return t.transport.RemainingBytes() +} + +// GetReadHeaders returns the THeaderMap read from transport. +func (t *THeaderTransport) GetReadHeaders() THeaderMap { + return t.readHeaders +} + +// SetWriteHeader sets a header for write. +func (t *THeaderTransport) SetWriteHeader(key, value string) { + t.writeHeaders[key] = value +} + +// ClearWriteHeaders clears all write headers previously set. +func (t *THeaderTransport) ClearWriteHeaders() { + t.writeHeaders = make(THeaderMap) +} + +// AddTransform add a transform for writing. +func (t *THeaderTransport) AddTransform(transform THeaderTransformID) error { + if !supportedTransformIDs[transform] { + return NewTProtocolExceptionWithType( + NOT_IMPLEMENTED, + fmt.Errorf("THeaderTransformID %d not supported", transform), + ) + } + t.writeTransforms = append(t.writeTransforms, transform) + return nil +} + +// Protocol returns the wrapped protocol id used in this THeaderTransport. +func (t *THeaderTransport) Protocol() THeaderProtocolID { + switch t.clientType { + default: + return t.protocolID + case clientFramedBinary, clientUnframedBinary: + return THeaderProtocolBinary + case clientFramedCompact, clientUnframedCompact: + return THeaderProtocolCompact + } +} + +func (t *THeaderTransport) isFramed() bool { + switch t.clientType { + default: + return false + case clientHeaders, clientFramedBinary, clientFramedCompact: + return true + } +} + +// THeaderTransportFactory is a TTransportFactory implementation to create +// THeaderTransport. +type THeaderTransportFactory struct { + // The underlying factory, could be nil. + Factory TTransportFactory +} + +// NewTHeaderTransportFactory creates a new *THeaderTransportFactory. +func NewTHeaderTransportFactory(factory TTransportFactory) TTransportFactory { + return &THeaderTransportFactory{ + Factory: factory, + } +} + +// GetTransport implements TTransportFactory. +func (f *THeaderTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + if f.Factory != nil { + t, err := f.Factory.GetTransport(trans) + if err != nil { + return nil, err + } + return NewTHeaderTransport(t), nil + } + return NewTHeaderTransport(trans), nil +} diff --git a/github.com/apache/thrift/lib/go/thrift/json_protocol.go b/github.com/apache/thrift/lib/go/thrift/json_protocol.go index 7be685d43f..800ac22c7b 100644 --- a/github.com/apache/thrift/lib/go/thrift/json_protocol.go +++ b/github.com/apache/thrift/lib/go/thrift/json_protocol.go @@ -32,10 +32,7 @@ const ( // for references to _ParseContext see tsimplejson_protocol.go // JSON protocol implementation for thrift. -// -// This protocol produces/consumes a simple output format -// suitable for parsing by scripting languages. It should not be -// confused with the full-featured TJSONProtocol. +// Utilizes Simple JSON protocol // type TJSONProtocol struct { *TSimpleJSONProtocol diff --git a/github.com/apache/thrift/lib/go/thrift/pointerize.go b/github.com/apache/thrift/lib/go/thrift/pointerize.go index 8d6b2c2159..fb564ea819 100644 --- a/github.com/apache/thrift/lib/go/thrift/pointerize.go +++ b/github.com/apache/thrift/lib/go/thrift/pointerize.go @@ -41,6 +41,8 @@ package thrift func Float32Ptr(v float32) *float32 { return &v } func Float64Ptr(v float64) *float64 { return &v } func IntPtr(v int) *int { return &v } +func Int8Ptr(v int8) *int8 { return &v } +func Int16Ptr(v int16) *int16 { return &v } func Int32Ptr(v int32) *int32 { return &v } func Int64Ptr(v int64) *int64 { return &v } func StringPtr(v string) *string { return &v } diff --git a/github.com/apache/thrift/lib/go/thrift/protocol.go b/github.com/apache/thrift/lib/go/thrift/protocol.go index 615b7a4a8f..2e6bc4b161 100644 --- a/github.com/apache/thrift/lib/go/thrift/protocol.go +++ b/github.com/apache/thrift/lib/go/thrift/protocol.go @@ -96,8 +96,6 @@ func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) { } switch fieldType { - case STOP: - return case BOOL: _, err = self.ReadBool() return diff --git a/github.com/apache/thrift/lib/go/thrift/simple_json_protocol.go b/github.com/apache/thrift/lib/go/thrift/simple_json_protocol.go index 2e8a71112a..f5e0c05d18 100644 --- a/github.com/apache/thrift/lib/go/thrift/simple_json_protocol.go +++ b/github.com/apache/thrift/lib/go/thrift/simple_json_protocol.go @@ -60,7 +60,7 @@ func (p _ParseContext) String() string { return "UNKNOWN-PARSE-CONTEXT" } -// JSON protocol implementation for thrift. +// Simple JSON protocol implementation for thrift. // // This protocol produces/consumes a simple output format // suitable for parsing by scripting languages. It should not be @@ -1316,7 +1316,7 @@ func (p *TSimpleJSONProtocol) readNumeric() (Numeric, error) { func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool { for i := 0; i < len(b); i++ { a, _ := p.reader.Peek(i + 1) - if len(a) == 0 || a[i] != b[i] { + if len(a) < (i+1) || a[i] != b[i] { return false } } diff --git a/github.com/apache/thrift/lib/go/thrift/simple_server.go b/github.com/apache/thrift/lib/go/thrift/simple_server.go index 6035802516..f8efbed914 100644 --- a/github.com/apache/thrift/lib/go/thrift/simple_server.go +++ b/github.com/apache/thrift/lib/go/thrift/simple_server.go @@ -42,6 +42,9 @@ type TSimpleServer struct { outputTransportFactory TTransportFactory inputProtocolFactory TProtocolFactory outputProtocolFactory TProtocolFactory + + // Headers to auto forward in THeaderProtocol + forwardHeaders []string } func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer { @@ -125,6 +128,26 @@ func (p *TSimpleServer) Listen() error { return p.serverTransport.Listen() } +// SetForwardHeaders sets the list of header keys that will be auto forwarded +// while using THeaderProtocol. +// +// "forward" means that when the server is also a client to other upstream +// thrift servers, the context object user gets in the processor functions will +// have both read and write headers set, with write headers being forwarded. +// Users can always override the write headers by calling SetWriteHeaderList +// before calling thrift client functions. +func (p *TSimpleServer) SetForwardHeaders(headers []string) { + size := len(headers) + if size == 0 { + p.forwardHeaders = nil + return + } + + keys := make([]string, size) + copy(keys, headers) + p.forwardHeaders = keys +} + func (p *TSimpleServer) innerAccept() (int32, error) { client, err := p.serverTransport.Accept() p.mu.Lock() @@ -187,12 +210,25 @@ func (p *TSimpleServer) processRequests(client TTransport) error { if err != nil { return err } - outputTransport, err := p.outputTransportFactory.GetTransport(client) - if err != nil { - return err - } inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport) - outputProtocol := p.outputProtocolFactory.GetProtocol(outputTransport) + var outputTransport TTransport + var outputProtocol TProtocol + + // for THeaderProtocol, we must use the same protocol instance for + // input and output so that the response is in the same dialect that + // the server detected the request was in. + headerProtocol, ok := inputProtocol.(*THeaderProtocol) + if ok { + outputProtocol = inputProtocol + } else { + oTrans, err := p.outputTransportFactory.GetTransport(client) + if err != nil { + return err + } + outputTransport = oTrans + outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport) + } + defer func() { if e := recover(); e != nil { log.Printf("panic in processor: %s: %s", e, debug.Stack()) @@ -210,7 +246,22 @@ func (p *TSimpleServer) processRequests(client TTransport) error { return nil } - ok, err := processor.Process(defaultCtx, inputProtocol, outputProtocol) + ctx := defaultCtx + if headerProtocol != nil { + // We need to call ReadFrame here, otherwise we won't + // get any headers on the AddReadTHeaderToContext call. + // + // ReadFrame is safe to be called multiple times so it + // won't break when it's called again later when we + // actually start to read the message. + if err := headerProtocol.ReadFrame(); err != nil { + return err + } + ctx = AddReadTHeaderToContext(defaultCtx, headerProtocol.GetReadHeaders()) + ctx = SetWriteHeaderList(ctx, p.forwardHeaders) + } + + ok, err := processor.Process(ctx, inputProtocol, outputProtocol) if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE { return nil } else if err != nil { diff --git a/github.com/apache/thrift/lib/go/thrift/socket.go b/github.com/apache/thrift/lib/go/thrift/socket.go index 8854279651..88b98f5916 100644 --- a/github.com/apache/thrift/lib/go/thrift/socket.go +++ b/github.com/apache/thrift/lib/go/thrift/socket.go @@ -162,5 +162,5 @@ func (p *TSocket) Interrupt() error { func (p *TSocket) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) - return maxSize // the thruth is, we just don't know unless framed is used + return maxSize // the truth is, we just don't know unless framed is used } diff --git a/modules.txt b/modules.txt index e405f48474..52cb54b191 100644 --- a/modules.txt +++ b/modules.txt @@ -120,7 +120,7 @@ github.com/apache/arrow/go/arrow/float16 github.com/apache/arrow/go/arrow/internal/cpu github.com/apache/arrow/go/arrow/internal/debug github.com/apache/arrow/go/arrow/memory -# github.com/apache/thrift v0.0.0-20181211084444-2b7365c54f82 +# github.com/apache/thrift v0.13.0 ## explicit github.com/apache/thrift/lib/go/thrift # github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e