diff --git a/chunk.go b/chunk.go index 2781ec3..8140367 100644 --- a/chunk.go +++ b/chunk.go @@ -222,27 +222,49 @@ func (this *Value) readBufferAt(chunk *Chunk, offset uint64) (uint64, error) { return 0, errors.New("Unsuported type") } -func protocolBufferFromValue(v interface{}) [][]byte { +func protocolBufferFromValue(v interface{}) ([][]byte, error) { switch v := v.(type) { case nil: - return protocolBufferFromNull() + return protocolBufferFromNull(), nil case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return protocolBufferFromInt(v) - case float32, float64: - return protocolBufferFromFloat(v) + return protocolBufferFromInt(v), nil + case float32: + return protocolBufferFromFloat(float64(v)), nil + case float64: + return protocolBufferFromFloat(v), nil case string: - return protocolBufferFromString(v, true) + return protocolBufferFromString(v, true), nil case []byte: - return protocolBufferFromBytes(v) + return protocolBufferFromBytes(v), nil default: rv := reflect.ValueOf(v) - if rv.Kind() == reflect.Ptr { + if !rv.IsValid() { + return protocolBufferFromNull(), nil + } + if rv.Kind() == reflect.Pointer { if rv.IsNil() { - return protocolBufferFromNull() + return protocolBufferFromNull(), nil } return protocolBufferFromValue(rv.Elem().Interface()) } - return make([][]byte, 0) + + switch rv.Kind() { + case reflect.String: + return protocolBufferFromString(rv.String(), true), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return protocolBufferFromInt(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return protocolBufferFromInt(rv.Uint()), nil + case reflect.Float32, reflect.Float64: + return protocolBufferFromFloat(rv.Convert(reflect.TypeOf(float64(0))).Float()), nil + case reflect.Bool: + if rv.Bool() { + return protocolBufferFromInt(1), nil + } + return protocolBufferFromInt(0), nil + default: + return nil, fmt.Errorf("unsupported parameter type %T", v) + } } } @@ -263,14 +285,7 @@ func protocolBufferFromInt(v interface{}) [][]byte { } func protocolBufferFromFloat(v interface{}) [][]byte { - var f float64 - switch v := v.(type) { - case float32: - f = float64(v) - case float64: - f = v - } - return [][]byte{[]byte(fmt.Sprintf("%c%s ", CMD_FLOAT, strconv.FormatFloat(f, 'f', -1, 64)))} + return [][]byte{[]byte(fmt.Sprintf("%c%s ", CMD_FLOAT, strconv.FormatFloat(v.(float64), 'f', -1, 64)))} } // func protocolBufferFromFloat(v interface{}) [][]byte { @@ -386,7 +401,11 @@ func (this *SQCloud) sendArray(command string, values []interface{}) (int, error // convert values to buffers encoded with whe sqlitecloud protocol buffers := [][]byte{protocolBufferFromString(command, true)[0]} for _, v := range values { - buffers = append(buffers, protocolBufferFromValue(v)...) + valueBuffers, err := protocolBufferFromValue(v) + if err != nil { + return 0, err + } + buffers = append(buffers, valueBuffers...) } // calculate the array header diff --git a/chunk_internal_test.go b/chunk_internal_test.go index 5c2d70b..e1a5021 100644 --- a/chunk_internal_test.go +++ b/chunk_internal_test.go @@ -1,109 +1,158 @@ package sqlitecloud import ( + "fmt" + "strings" "testing" ) +type testStringEnum string +type testIntEnum int + func TestProtocolBufferFromValue(t *testing.T) { + type unsupported struct{} + intVal := 42 + strVal := "hello" + tests := []struct { - name string - value interface{} - wantLen int // expected number of []byte buffers returned - wantType byte + name string + value interface{} + wantLen int + wantType byte + wantError bool }{ - // Basic types - {"nil", nil, 1, CMD_NULL}, - {"string", "hello", 1, CMD_ZEROSTRING}, - {"int", int(42), 1, CMD_INT}, - {"int8", int8(8), 1, CMD_INT}, - {"int16", int16(16), 1, CMD_INT}, - {"int32", int32(32), 1, CMD_INT}, - {"int64", int64(64), 1, CMD_INT}, - {"float32", float32(3.14), 1, CMD_FLOAT}, - {"float64", float64(2.71), 1, CMD_FLOAT}, - {"[]byte", []byte("blob"), 2, CMD_BLOB}, // header + data - - // Unsigned integers - {"uint", uint(1), 1, CMD_INT}, - {"uint8", uint8(1), 1, CMD_INT}, - {"uint16", uint16(1), 1, CMD_INT}, - {"uint32", uint32(1), 1, CMD_INT}, - {"uint64", uint64(1), 1, CMD_INT}, - - // Pointer types (dereferenced) - {"*int", intPtr(42), 1, CMD_INT}, - {"*string", strPtr("hello"), 1, CMD_ZEROSTRING}, - {"*int nil", (*int)(nil), 1, CMD_NULL}, - {"*string nil", (*string)(nil), 1, CMD_NULL}, - - // Unsupported types still return empty buffers - {"bool", true, 0, 0}, + {"nil", nil, 1, CMD_NULL, false}, + {"string", "hello", 1, CMD_ZEROSTRING, false}, + {"int", int(42), 1, CMD_INT, false}, + {"int8", int8(8), 1, CMD_INT, false}, + {"int16", int16(16), 1, CMD_INT, false}, + {"int32", int32(32), 1, CMD_INT, false}, + {"int64", int64(64), 1, CMD_INT, false}, + {"uint", uint(1), 1, CMD_INT, false}, + {"uint8", uint8(1), 1, CMD_INT, false}, + {"uint16", uint16(1), 1, CMD_INT, false}, + {"uint32", uint32(1), 1, CMD_INT, false}, + {"uint64", uint64(1), 1, CMD_INT, false}, + {"float32", float32(3.14), 1, CMD_FLOAT, false}, + {"float64", float64(2.71), 1, CMD_FLOAT, false}, + {"[]byte", []byte("blob"), 2, CMD_BLOB, false}, + {"bool true", true, 1, CMD_INT, false}, + {"bool false", false, 1, CMD_INT, false}, + {"*int", &intVal, 1, CMD_INT, false}, + {"*string", &strVal, 1, CMD_ZEROSTRING, false}, + {"*int nil", (*int)(nil), 1, CMD_NULL, false}, + {"*string nil", (*string)(nil), 1, CMD_NULL, false}, + {"unsupported", unsupported{}, 0, 0, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - buffers := protocolBufferFromValue(tt.value) + buffers, err := protocolBufferFromValue(tt.value) + if tt.wantError { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if len(buffers) != tt.wantLen { - t.Errorf("protocolBufferFromValue(%T(%v)): got %d buffers, want %d", tt.value, tt.value, len(buffers), tt.wantLen) + t.Fatalf("got %d buffers, want %d", len(buffers), tt.wantLen) } - if tt.wantLen > 0 && len(buffers) > 0 { - if buffers[0][0] != tt.wantType { - t.Errorf("protocolBufferFromValue(%T(%v)): got type %c, want %c", tt.value, tt.value, buffers[0][0], tt.wantType) - } + if tt.wantLen > 0 && buffers[0][0] != tt.wantType { + t.Fatalf("got first buffer type %q, want %q", buffers[0][0], tt.wantType) } }) } } -func TestProtocolBufferFromValueMixedArray(t *testing.T) { - // Simulates the loop in sendArray: builds buffers from a mixed values slice - // and checks that the number of buffer groups matches the number of values. - pInt := intPtr(99) - values := []interface{}{ - "hello", // string -> 1 buffer - int(42), // int -> 1 buffer - nil, // nil -> 1 buffer - pInt, // *int -> 1 buffer (dereferenced to int) - float64(3), // float64 -> 1 buffer - uint(7), // uint -> 1 buffer - []byte("x"), // []byte -> 2 buffers (header+data) +func TestProtocolBufferFromValueSupportsStringAlias(t *testing.T) { + val := testStringEnum("active") + buffers, err := protocolBufferFromValue(val) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - // Count how many values produce at least one buffer - buffersPerValue := make([]int, len(values)) - totalBuffers := 0 - missingValues := 0 + if len(buffers) != 1 { + t.Fatalf("expected 1 buffer, got %d", len(buffers)) + } + got := string(buffers[0]) + want := fmt.Sprintf("%c%d %s\x00", CMD_ZEROSTRING, len("active")+1, "active") + if got != want { + t.Fatalf("unexpected encoded value: want %q got %q", want, got) + } +} - for i, v := range values { - bufs := protocolBufferFromValue(v) - buffersPerValue[i] = len(bufs) - totalBuffers += len(bufs) - if len(bufs) == 0 { - missingValues++ - t.Errorf("value[%d] (%T = %v) produced 0 buffers — will be silently dropped", i, v, v) - } +func TestProtocolBufferFromValueSupportsIntAliasPointer(t *testing.T) { + raw := testIntEnum(7) + buffers, err := protocolBufferFromValue(&raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - if missingValues > 0 { - t.Errorf("%d out of %d values produced no buffers and will be missing from the protocol message", missingValues, len(values)) + if len(buffers) != 1 { + t.Fatalf("expected 1 buffer, got %d", len(buffers)) } + got := string(buffers[0]) + want := fmt.Sprintf("%c%d ", CMD_INT, 7) + if got != want { + t.Fatalf("unexpected encoded value: want %q got %q", want, got) + } +} - // Reproduce the exact loop from sendArray - buffers := [][]byte{} - for _, v := range values { - buffers = append(buffers, protocolBufferFromValue(v)...) +func TestProtocolBufferFromValueSupportsFloat32(t *testing.T) { + buffers, err := protocolBufferFromValue(float32(2.5)) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - t.Logf("values count: %d, total buffers: %d, buffers per value: %v", len(values), len(buffers), buffersPerValue) + if len(buffers) != 1 { + t.Fatalf("expected 1 buffer, got %d", len(buffers)) + } + got := string(buffers[0]) + if !strings.HasPrefix(got, fmt.Sprintf("%c", CMD_FLOAT)) { + t.Fatalf("expected float buffer prefix, got %q", got) + } +} + +func TestProtocolBufferFromValueUnsupportedTypeReturnsError(t *testing.T) { + type unsupported struct { + Name string + } - // Every value must produce at least 1 buffer ([]byte produces 2) - expectedMinBuffers := len(values) - if len(buffers) < expectedMinBuffers { - t.Errorf("buffers array has %d elements, expected at least %d (one per value). %d values were silently dropped.", - len(buffers), expectedMinBuffers, missingValues) + _, err := protocolBufferFromValue(unsupported{Name: "x"}) + if err == nil { + t.Fatalf("expected error for unsupported type") } } -// helpers -func intPtr(v int) *int { return &v } -func strPtr(v string) *string { return &v } +func TestProtocolBufferFromValueMixedArrayNoSilentDrops(t *testing.T) { + pInt := 99 + values := []interface{}{ + "hello", + int(42), + nil, + &pInt, + float64(3), + uint(7), + []byte("x"), + true, + } + + buffers := [][]byte{} + for i, v := range values { + valueBuffers, err := protocolBufferFromValue(v) + if err != nil { + t.Fatalf("unexpected error at index %d (%T): %v", i, v, err) + } + if len(valueBuffers) == 0 { + t.Fatalf("value at index %d produced zero buffers", i) + } + buffers = append(buffers, valueBuffers...) + } + + if len(buffers) < len(values) { + t.Fatalf("got %d total buffers, expected at least %d", len(buffers), len(values)) + } +}