diff --git a/client/rpcdataset.go b/client/rpcdataset.go index da97ef6..e86adfa 100644 --- a/client/rpcdataset.go +++ b/client/rpcdataset.go @@ -319,6 +319,10 @@ func (s *IoTDBRpcDataSet) getBooleanByTsBlockColumnIndex(tsBlockColumnIndex int3 if err := s.checkRecord(); err != nil { return false, err } + if tsBlockColumnIndex < 0 { + s.lastReadWasNull = false + return false, nil + } if !s.isNull(tsBlockColumnIndex, s.tsBlockIndex) { s.lastReadWasNull = false return s.curTsBlock.GetColumn(tsBlockColumnIndex).GetBoolean(s.tsBlockIndex) @@ -348,6 +352,10 @@ func (s *IoTDBRpcDataSet) getDoubleByTsBlockColumnIndex(tsBlockColumnIndex int32 if err := s.checkRecord(); err != nil { return 0, err } + if tsBlockColumnIndex < 0 { + s.lastReadWasNull = false + return 0, nil + } if !s.isNull(tsBlockColumnIndex, s.tsBlockIndex) { s.lastReadWasNull = false return s.curTsBlock.GetColumn(tsBlockColumnIndex).GetDouble(s.tsBlockIndex) @@ -377,6 +385,10 @@ func (s *IoTDBRpcDataSet) getFloatByTsBlockColumnIndex(tsBlockColumnIndex int32) if err := s.checkRecord(); err != nil { return 0, err } + if tsBlockColumnIndex < 0 { + s.lastReadWasNull = false + return 0, nil + } if !s.isNull(tsBlockColumnIndex, s.tsBlockIndex) { s.lastReadWasNull = false return s.curTsBlock.GetColumn(tsBlockColumnIndex).GetFloat(s.tsBlockIndex) @@ -406,6 +418,10 @@ func (s *IoTDBRpcDataSet) getIntByTsBlockColumnIndex(tsBlockColumnIndex int32) ( if err := s.checkRecord(); err != nil { return 0, err } + if tsBlockColumnIndex < 0 { + s.lastReadWasNull = false + return 0, nil + } if !s.isNull(tsBlockColumnIndex, s.tsBlockIndex) { s.lastReadWasNull = false dataType := s.curTsBlock.GetColumn(tsBlockColumnIndex).GetDataType() @@ -484,6 +500,10 @@ func (s *IoTDBRpcDataSet) getBinaryByTsBlockColumnIndex(tsBlockColumnIndex int32 if err := s.checkRecord(); err != nil { return nil, err } + if tsBlockColumnIndex < 0 { + s.lastReadWasNull = false + return nil, nil + } if !s.isNull(tsBlockColumnIndex, s.tsBlockIndex) { s.lastReadWasNull = false return s.curTsBlock.GetColumn(tsBlockColumnIndex).GetBinary(s.tsBlockIndex) @@ -513,6 +533,10 @@ func (s *IoTDBRpcDataSet) getObjectByTsBlockIndex(tsBlockColumnIndex int32) (int if err := s.checkRecord(); err != nil { return nil, err } + if tsBlockColumnIndex < 0 { + s.lastReadWasNull = false + return nil, nil + } if s.isNull(tsBlockColumnIndex, s.tsBlockIndex) { s.lastReadWasNull = true return nil, nil diff --git a/client/session.go b/client/session.go index 28b326e..eca7494 100644 --- a/client/session.go +++ b/client/session.go @@ -524,6 +524,17 @@ func (s *Session) ExecuteStatement(sql string) (*SessionDataSet, error) { return s.ExecuteStatementWithContext(context.Background(), sql) } +func (s *Session) Ping(ctx context.Context) error { + status, err := s.client.TestConnectionEmptyRPC(ctx) + if err != nil { + return err + } + if status.GetCode() == SuccessStatus { + return nil + } + return errors.New("Ping failed: " + status.GetMessage()) +} + func (s *Session) ExecuteNonQueryStatement(sql string) error { request := rpc.TSExecuteStatementReq{ SessionId: s.sessionId, diff --git a/database/batch.go b/database/batch.go new file mode 100644 index 0000000..aa2c923 --- /dev/null +++ b/database/batch.go @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "database/sql/driver" + + "github.com/pkg/errors" +) + +type stdBatch struct { + debugf func(format string, v ...any) +} + +func (s *stdBatch) NumInput() int { return -1 } + +func (s *stdBatch) Exec(args []driver.Value) (driver.Result, error) { + return nil, errors.New("not implemented") +} + +func (s *stdBatch) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + return nil, driver.ErrSkip +} + +func (s *stdBatch) Query(args []driver.Value) (driver.Rows, error) { + // Note: not implementing driver.StmtQueryContext accordingly + return nil, errors.New("only Exec method supported in batch mode") +} + +func (s *stdBatch) Close() error { + return nil +} diff --git a/database/bind.go b/database/bind.go new file mode 100644 index 0000000..e608277 --- /dev/null +++ b/database/bind.go @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + std_driver "database/sql/driver" + "fmt" + "reflect" + "regexp" + "strings" + "time" + + "github.com/apache/iotdb-client-go/v2/database/driver" + "github.com/pkg/errors" +) + +var ( + ErrInvalidTimezone = errors.New("invalid timezone value") +) + +func Named(name string, value any) driver.NamedValue { + return driver.NamedValue{ + Name: name, + Value: value, + } +} + +type TimeUnit uint8 + +const ( + Seconds TimeUnit = iota + MilliSeconds + MicroSeconds + NanoSeconds +) + +type GroupSet struct { + Value []any +} + +type ArraySet []any + +func DateNamed(name string, value time.Time, scale TimeUnit) driver.NamedDateValue { + return driver.NamedDateValue{ + Name: name, + Value: value, + Scale: uint8(scale), + } +} + +var ( + bindNumericRe = regexp.MustCompile(`\$[0-9]+`) + bindPositionalRe = regexp.MustCompile(`[^\\][?]`) +) + +func bind(tz *time.Location, query string, args ...any) (string, error) { + if len(args) == 0 { + return query, nil + } + var ( + haveNumeric bool + havePositional bool + ) + + allArgumentsNamed, err := checkAllNamedArguments(args...) + if err != nil { + return "", err + } + + if allArgumentsNamed { + return bindNamed(tz, query, args...) + } + + haveNumeric = bindNumericRe.MatchString(query) + havePositional = bindPositionalRe.MatchString(query) + if haveNumeric && havePositional { + return "", ErrBindMixedParamsFormats + } + if haveNumeric { + return bindNumeric(tz, query, args...) + } + return bindPositional(tz, query, args...) +} + +func checkAllNamedArguments(args ...any) (bool, error) { + var ( + haveNamed bool + haveAnonymous bool + ) + for _, v := range args { + switch v.(type) { + case driver.NamedValue, driver.NamedDateValue: + haveNamed = true + default: + haveAnonymous = true + } + if haveNamed && haveAnonymous { + return haveNamed, ErrBindMixedParamsFormats + } + } + return haveNamed, nil +} + +func bindPositional(tz *time.Location, query string, args ...any) (_ string, err error) { + var ( + lastMatchIndex = -1 // Position of previous match for copying + argIndex = 0 // Index for the argument at current position + buf = make([]byte, 0, len(query)) + unbindCount = 0 // Number of positional arguments that couldn't be matched + ) + + for i := 0; i < len(query); i++ { + // It's fine looping through the query string as bytes, because the (fixed) characters we're looking for + // are in the ASCII range to won't take up more than one byte. + if query[i] == '?' { + if i > 0 && query[i-1] == '\\' { + // Copy all previous index to here characters + buf = append(buf, query[lastMatchIndex+1:i-1]...) + buf = append(buf, '?') + } else { + // Copy all previous index to here characters + buf = append(buf, query[lastMatchIndex+1:i]...) + + // Append the argument value + if argIndex < len(args) { + v := args[argIndex] + if fn, ok := v.(std_driver.Valuer); ok { + if v, err = fn.Value(); err != nil { + return "", err + } + } + + value, err := format(tz, Seconds, v) + if err != nil { + return "", err + } + + buf = append(buf, value...) + argIndex++ + } else { + unbindCount++ + } + } + + lastMatchIndex = i + } + } + + // If there were no replacements, quick return without copying the string + if lastMatchIndex < 0 { + return query, nil + } + + // Append the remainder + buf = append(buf, query[lastMatchIndex+1:]...) + + if unbindCount > 0 { + return "", fmt.Errorf("have no arg for param ? at last %d positions", unbindCount) + } + + return string(buf), nil +} + +func bindNumeric(tz *time.Location, query string, args ...any) (_ string, err error) { + var ( + unbind = make(map[string]struct{}) + params = make(map[string]string) + ) + for i, v := range args { + if fn, ok := v.(std_driver.Valuer); ok { + if v, err = fn.Value(); err != nil { + return "", err + } + } + val, err := format(tz, Seconds, v) + if err != nil { + return "", err + } + params[fmt.Sprintf("$%d", i+1)] = val + } + query = bindNumericRe.ReplaceAllStringFunc(query, func(n string) string { + if _, found := params[n]; !found { + unbind[n] = struct{}{} + return "" + } + return params[n] + }) + for param := range unbind { + return "", fmt.Errorf("have no arg for %s param", param) + } + return query, nil +} + +var bindNamedRe = regexp.MustCompile(`@[a-zA-Z0-9\_]+`) + +func bindNamed(tz *time.Location, query string, args ...any) (_ string, err error) { + var ( + unbind = make(map[string]struct{}) + params = make(map[string]string) + ) + for _, v := range args { + switch v := v.(type) { + case driver.NamedValue: + value := v.Value + if fn, ok := v.Value.(std_driver.Valuer); ok { + if value, err = fn.Value(); err != nil { + return "", err + } + } + val, err := format(tz, Seconds, value) + if err != nil { + return "", err + } + params["@"+v.Name] = val + case driver.NamedDateValue: + val, err := format(tz, TimeUnit(v.Scale), v.Value) + if err != nil { + return "", err + } + params["@"+v.Name] = val + } + } + query = bindNamedRe.ReplaceAllStringFunc(query, func(n string) string { + if _, found := params[n]; !found { + unbind[n] = struct{}{} + return "" + } + return params[n] + }) + for param := range unbind { + return "", fmt.Errorf("have no arg for %q param", param) + } + return query, nil +} + +func formatTime(value time.Time) (string, error) { + str := value.Format(time.DateTime) + return str, nil +} + +var stringQuoteReplacer = strings.NewReplacer(`\`, `\\`, `'`, `\'`) + +func format(tz *time.Location, scale TimeUnit, v any) (string, error) { + quote := func(v string) string { + return "'" + stringQuoteReplacer.Replace(v) + "'" + } + switch v := v.(type) { + case nil: + return "NULL", nil + case string: + return quote(v), nil + case time.Time: + return formatTime(v) + case *time.Time: + if v == nil { + return "NULL", nil + } + return formatTime(*v) + case bool: + if v { + return "1", nil + } + return "0", nil + case GroupSet: + val, err := join(tz, scale, v.Value) + if err != nil { + return "", err + } + return fmt.Sprintf("(%s)", val), nil + case []GroupSet: + val, err := join(tz, scale, v) + if err != nil { + return "", err + } + return val, err + case ArraySet: + val, err := join(tz, scale, v) + if err != nil { + return "", err + } + return fmt.Sprintf("[%s]", val), nil + case fmt.Stringer: + if v := reflect.ValueOf(v); v.Kind() == reflect.Pointer && + v.IsNil() && + v.Type().Elem().Implements(reflect.TypeOf((*fmt.Stringer)(nil)).Elem()) { + return "NULL", nil + } + return quote(v.String()), nil + } + switch v := reflect.ValueOf(v); v.Kind() { + case reflect.String: + return quote(v.String()), nil + case reflect.Slice, reflect.Array: + values := make([]string, 0, v.Len()) + for i := 0; i < v.Len(); i++ { + val, err := format(tz, scale, v.Index(i).Interface()) + if err != nil { + return "", err + } + values = append(values, val) + } + return fmt.Sprintf("[%s]", strings.Join(values, ", ")), nil + case reflect.Map: // map + values := make([]string, 0, len(v.MapKeys())) + for _, key := range v.MapKeys() { + name := fmt.Sprint(key.Interface()) + if key.Kind() == reflect.String { + name = fmt.Sprintf("'%s'", name) + } + val, err := format(tz, scale, v.MapIndex(key).Interface()) + if err != nil { + return "", err + } + values = append(values, fmt.Sprintf("%s, %s", name, val)) + } + return "map(" + strings.Join(values, ", ") + ")", nil + case reflect.Ptr: + if v.IsNil() { + return "NULL", nil + } + return format(tz, scale, v.Elem().Interface()) + } + return fmt.Sprint(v), nil +} + +func join[E any](tz *time.Location, scale TimeUnit, values []E) (string, error) { + items := make([]string, len(values), len(values)) + for i := range values { + val, err := format(tz, scale, values[i]) + if err != nil { + return "", err + } + items[i] = val + } + return strings.Join(items, ", "), nil +} + +func rebind(in []std_driver.NamedValue) []any { + args := make([]any, 0, len(in)) + for _, v := range in { + switch { + case len(v.Name) != 0: + args = append(args, driver.NamedValue{ + Name: v.Name, + Value: v.Value, + }) + + default: + args = append(args, v.Value) + } + } + return args +} diff --git a/database/bind_test.go b/database/bind_test.go new file mode 100644 index 0000000..fd8d446 --- /dev/null +++ b/database/bind_test.go @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "testing" + "time" + + "github.com/apache/iotdb-client-go/v2/database/driver" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ==================== Positional Binding (?) ==================== + +// TestBind_Positional_Basic verifies basic positional placeholder replacement. +func TestBind_Positional_Basic(t *testing.T) { + result, err := bind(nil, "INSERT INTO t(a, b) VALUES (?, ?)", "hello", 42) + require.NoError(t, err) + assert.Equal(t, "INSERT INTO t(a, b) VALUES ('hello', 42)", result) +} + +// TestBind_Positional_MultipleValues verifies multiple positional args in one query. +func TestBind_Positional_MultipleValues(t *testing.T) { + result, err := bind(nil, "SELECT * FROM t WHERE a = ? AND b = ? AND c = ?", "x", 1, true) + require.NoError(t, err) + assert.Equal(t, "SELECT * FROM t WHERE a = 'x' AND b = 1 AND c = 1", result) +} + +// TestBind_Positional_EscapedQuestionMark verifies that \? is not treated as a placeholder. +func TestBind_Positional_EscapedQuestionMark(t *testing.T) { + result, err := bind(nil, `SELECT * FROM t WHERE a = \? AND b = ?`, "val") + require.NoError(t, err) + assert.Equal(t, "SELECT * FROM t WHERE a = ? AND b = 'val'", result) +} + +// TestBind_Positional_TooFewArgs returns an error when args are fewer than placeholders. +func TestBind_Positional_TooFewArgs(t *testing.T) { + _, err := bind(nil, "INSERT INTO t VALUES (?, ?, ?)", "a") + assert.Error(t, err) +} + +// TestBind_Positional_NoArgs returns the query unchanged when no args are provided. +func TestBind_Positional_NoArgs(t *testing.T) { + result, err := bind(nil, "SELECT * FROM t") + require.NoError(t, err) + assert.Equal(t, "SELECT * FROM t", result) +} + +// ==================== Numeric Binding ($N) ==================== + +// TestBind_Numeric_Basic verifies $1, $2 style placeholder replacement. +func TestBind_Numeric_Basic(t *testing.T) { + result, err := bind(nil, "SELECT * FROM t WHERE a = $1 AND b = $2", "foo", 100) + require.NoError(t, err) + assert.Equal(t, "SELECT * FROM t WHERE a = 'foo' AND b = 100", result) +} + +// TestBind_Numeric_OutOfOrder verifies that $N params can be used in any order. +func TestBind_Numeric_OutOfOrder(t *testing.T) { + result, err := bind(nil, "SELECT $2, $1 FROM t", "first", "second") + require.NoError(t, err) + assert.Equal(t, "SELECT 'second', 'first' FROM t", result) +} + +// TestBind_Numeric_Reuse verifies that the same $N can be used multiple times. +func TestBind_Numeric_Reuse(t *testing.T) { + result, err := bind(nil, "SELECT * FROM t WHERE a = $1 OR b = $1", "val") + require.NoError(t, err) + assert.Equal(t, "SELECT * FROM t WHERE a = 'val' OR b = 'val'", result) +} + +// TestBind_Numeric_MissingArg returns an error when a referenced $N has no corresponding arg. +func TestBind_Numeric_MissingArg(t *testing.T) { + _, err := bind(nil, "SELECT * FROM t WHERE a = $1 AND b = $3", "only_one") + assert.Error(t, err) +} + +// ==================== Named Binding (@name) ==================== + +// TestBind_Named_Basic verifies @name style placeholder replacement. +func TestBind_Named_Basic(t *testing.T) { + result, err := bind(nil, "SELECT * FROM t WHERE name = @name AND age = @age", + driver.NamedValue{Name: "name", Value: "Alice"}, + driver.NamedValue{Name: "age", Value: 30}, + ) + require.NoError(t, err) + assert.Equal(t, "SELECT * FROM t WHERE name = 'Alice' AND age = 30", result) +} + +// TestBind_Named_Reuse verifies the same @name can appear multiple times. +func TestBind_Named_Reuse(t *testing.T) { + result, err := bind(nil, "SELECT * FROM t WHERE a = @x OR b = @x", + driver.NamedValue{Name: "x", Value: "val"}, + ) + require.NoError(t, err) + assert.Equal(t, "SELECT * FROM t WHERE a = 'val' OR b = 'val'", result) +} + +// TestBind_Named_MissingParam returns an error when a @param in query has no matching arg. +func TestBind_Named_MissingParam(t *testing.T) { + _, err := bind(nil, "SELECT * FROM t WHERE a = @missing", + driver.NamedValue{Name: "other", Value: "val"}, + ) + assert.Error(t, err) +} + +// TestBind_Named_DateValue verifies NamedDateValue binds time correctly. +func TestBind_Named_DateValue(t *testing.T) { + ts := time.Date(2024, 6, 12, 15, 30, 45, 0, time.UTC) + result, err := bind(nil, "SELECT * FROM t WHERE ts = @ts", + driver.NamedDateValue{Name: "ts", Value: ts, Scale: uint8(Seconds)}, + ) + require.NoError(t, err) + assert.Equal(t, "SELECT * FROM t WHERE ts = 2024-06-12 15:30:45", result) +} + +// ==================== Mixed Format Detection ==================== + +// TestBind_Mixed_NamedAndPositional returns an error when mixing named and anonymous args. +func TestBind_Mixed_NamedAndPositional(t *testing.T) { + _, err := bind(nil, "SELECT * FROM t WHERE a = ? AND b = ?", + driver.NamedValue{Name: "a", Value: "x"}, + "anonymous", + ) + assert.ErrorIs(t, err, ErrBindMixedParamsFormats) +} + +// ==================== Value Formatting ==================== + +// TestFormat_Nil formats nil as NULL. +func TestFormat_Nil(t *testing.T) { + result, err := format(nil, Seconds, nil) + require.NoError(t, err) + assert.Equal(t, "NULL", result) +} + +// TestFormat_String wraps string in single quotes and escapes special chars. +func TestFormat_String(t *testing.T) { + result, err := format(nil, Seconds, "it's a \"test\"") + require.NoError(t, err) + assert.Equal(t, `'it\'s a "test"'`, result) +} + +// TestFormat_StringWithBackslash escapes backslashes. +func TestFormat_StringWithBackslash(t *testing.T) { + result, err := format(nil, Seconds, `path\to\file`) + require.NoError(t, err) + assert.Equal(t, `'path\\to\\file'`, result) +} + +// TestFormat_Bool formats true as 1 and false as 0. +func TestFormat_Bool(t *testing.T) { + trueResult, err := format(nil, Seconds, true) + require.NoError(t, err) + assert.Equal(t, "1", trueResult) + + falseResult, err := format(nil, Seconds, false) + require.NoError(t, err) + assert.Equal(t, "0", falseResult) +} + +// TestFormat_Time formats time.Time as "2006-01-02 15:04:05". +func TestFormat_Time(t *testing.T) { + ts := time.Date(2024, 6, 12, 15, 30, 45, 0, time.UTC) + result, err := format(nil, Seconds, ts) + require.NoError(t, err) + assert.Equal(t, "2024-06-12 15:30:45", result) +} + +// TestFormat_NilPointerTime formats nil *time.Time as NULL. +func TestFormat_NilPointerTime(t *testing.T) { + var ts *time.Time + result, err := format(nil, Seconds, ts) + require.NoError(t, err) + assert.Equal(t, "NULL", result) +} + +// TestFormat_Integer formats integers without quotes. +func TestFormat_Integer(t *testing.T) { + result, err := format(nil, Seconds, 42) + require.NoError(t, err) + assert.Equal(t, "42", result) +} + +// TestFormat_Float formats floats without quotes. +func TestFormat_Float(t *testing.T) { + result, err := format(nil, Seconds, 3.14) + require.NoError(t, err) + assert.Contains(t, result, "3.14") +} + +// TestFormat_GroupSet formats GroupSet as parenthesized comma-separated values. +func TestFormat_GroupSet(t *testing.T) { + result, err := format(nil, Seconds, GroupSet{Value: []any{"a", "b", "c"}}) + require.NoError(t, err) + assert.Equal(t, "('a', 'b', 'c')", result) +} + +// TestFormat_ArraySet formats ArraySet as bracketed comma-separated values. +func TestFormat_ArraySet(t *testing.T) { + result, err := format(nil, Seconds, ArraySet{"x", "y"}) + require.NoError(t, err) + assert.Equal(t, "['x', 'y']", result) +} + +// TestFormat_NilPointer formats a nil pointer as NULL. +func TestFormat_NilPointer(t *testing.T) { + var p *int + result, err := format(nil, Seconds, p) + require.NoError(t, err) + assert.Equal(t, "NULL", result) +} + +// TestFormat_Slice formats a slice as bracketed comma-separated values. +func TestFormat_Slice(t *testing.T) { + result, err := format(nil, Seconds, []int{1, 2, 3}) + require.NoError(t, err) + assert.Equal(t, "[1, 2, 3]", result) +} diff --git a/database/column/blob.go b/database/column/blob.go new file mode 100644 index 0000000..169a5b8 --- /dev/null +++ b/database/column/blob.go @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import "github.com/apache/iotdb-client-go/v2/client" + +type Blob struct { + name string +} + +func (b *Blob) Name() string { + return b.name +} +func (b *Blob) Type() Type { + return "BLOB" +} + +func (b *Blob) Row(stat *client.SessionDataSet, ptr bool) any { + if stat == nil { + if ptr { + return nil + } + return 0 + } + value, err := stat.GetBlob(b.name) + if err != nil { + if ptr { + return nil + } + return 0 + } + if ptr { + return &value + } + return value +} diff --git a/database/column/bool.go b/database/column/bool.go new file mode 100644 index 0000000..d9c7dd2 --- /dev/null +++ b/database/column/bool.go @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import ( + "github.com/apache/iotdb-client-go/v2/client" +) + +type Bool struct { + name string +} + +func (b *Bool) Name() string { + return b.name +} +func (b *Bool) Type() Type { + return "BOOLEAN" +} +func (b *Bool) Row(stat *client.SessionDataSet, ptr bool) any { + if stat == nil { + if ptr { + return nil + } + return 0 + } + value, err := stat.GetBoolean(b.name) + if err != nil { + if ptr { + return nil + } + return 0 + } + if ptr { + return &value + } + return value +} diff --git a/database/column/column.go b/database/column/column.go new file mode 100644 index 0000000..802246b --- /dev/null +++ b/database/column/column.go @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import ( + "github.com/apache/iotdb-client-go/v2/client" +) + +type Type string + +type Interface interface { + Name() string + Type() Type + Row(stat *client.SessionDataSet, ptr bool) any +} + +func GenColumn(dataType string, name string) Interface { + switch dataType { + case "BOOLEAN": + return &Bool{ + name: name, + } + case "INT32": + return &Int32{ + name: name, + } + case "INT64": + { + return &Int64{ + name: name, + } + } + case "FLOAT": + return &Float{ + name: name, + } + case "DOUBLE": + return &Double{ + name: name, + } + case "TEXT": + { + return &String{ + name: name, + } + } + case "TIMESTAMP": + return &Timestamp{ + name: name, + } + case "DATE": + return &Date{ + name: name, + } + case "BLOB": + return &Blob{ + name: name, + } + case "STRING": + return &String{ + name: name, + } + } + return nil +} diff --git a/database/column/date.go b/database/column/date.go new file mode 100644 index 0000000..6f0e66b --- /dev/null +++ b/database/column/date.go @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import "github.com/apache/iotdb-client-go/v2/client" + +type Date struct { + name string +} + +func (b *Date) Name() string { + return b.name +} +func (b *Date) Type() Type { + return "DATE" +} + +func (b *Date) Row(stat *client.SessionDataSet, ptr bool) any { + if stat == nil { + if ptr { + return nil + } + return 0 + } + value, err := stat.GetDate(b.name) + if err != nil { + if ptr { + return nil + } + return 0 + } + if ptr { + return &value + } + return value +} diff --git a/database/column/double.go b/database/column/double.go new file mode 100644 index 0000000..df933b9 --- /dev/null +++ b/database/column/double.go @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import ( + "github.com/apache/iotdb-client-go/v2/client" +) + +type Double struct { + name string +} + +func (d *Double) Name() string { + return d.name +} +func (d *Double) Type() Type { + return "DOUBLE" +} +func (d *Double) Row(stat *client.SessionDataSet, ptr bool) any { + if stat == nil { + if ptr { + return nil + } + return 0 + } + value, err := stat.GetDouble(d.name) + if err != nil { + if ptr { + return nil + } + return 0 + } + if ptr { + return &value + } + return value +} diff --git a/database/column/float.go b/database/column/float.go new file mode 100644 index 0000000..a2e1b72 --- /dev/null +++ b/database/column/float.go @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import ( + "github.com/apache/iotdb-client-go/v2/client" +) + +type Float struct { + name string +} + +func (f *Float) Name() string { + return f.name +} +func (f *Float) Type() Type { + return "FLOAT" +} + +func (f *Float) Row(stat *client.SessionDataSet, ptr bool) any { + if stat == nil { + if ptr { + return nil + } + return 0 + } + value, err := stat.GetFloat(f.name) + if err != nil { + if ptr { + return nil + } + return 0 + } + if ptr { + return &value + } + return value +} diff --git a/database/column/int32.go b/database/column/int32.go new file mode 100644 index 0000000..f44b531 --- /dev/null +++ b/database/column/int32.go @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import "github.com/apache/iotdb-client-go/v2/client" + +type Int32 struct { + name string +} + +func (i *Int32) Name() string { + return i.name +} +func (i *Int32) Type() Type { + return "INT32" +} +func (i *Int32) Row(stat *client.SessionDataSet, ptr bool) any { + if stat == nil { + if ptr { + return nil + } + return 0 + } + value, err := stat.GetInt(i.name) + if err != nil { + if ptr { + return nil + } + return 0 + } + if ptr { + return &value + } + return value +} diff --git a/database/column/int64.go b/database/column/int64.go new file mode 100644 index 0000000..68e9bd0 --- /dev/null +++ b/database/column/int64.go @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import ( + "reflect" + + "github.com/apache/iotdb-client-go/v2/client" +) + +type Int64 struct { + name string +} + +func (i *Int64) Name() string { + return i.name +} +func (i *Int64) Type() Type { + return "INT64" +} +func (i *Int64) Rows() int { + return 0 +} +func (i *Int64) Row(stat *client.SessionDataSet, ptr bool) any { + if stat == nil { + if ptr { + return nil + } + return 0 + } + value, err := stat.GetLong(i.name) + if err != nil { + if ptr { + return nil + } + return 0 + } + if ptr { + return &value + } + return value +} +func (i *Int64) ScanRow(dest any, row int) error { + return nil +} +func (i *Int64) Append(v any) (nulls []uint8, err error) { + return nil, nil +} +func (i *Int64) AppendRow(v any) error { + return nil +} +func (i *Int64) ScanType() reflect.Type { + return reflect.TypeOf(int64(0)) +} +func (i *Int64) Reset() { + return +} diff --git a/database/column/string.go b/database/column/string.go new file mode 100644 index 0000000..ba52d7e --- /dev/null +++ b/database/column/string.go @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import ( + "github.com/apache/iotdb-client-go/v2/client" +) + +type String struct { + name string +} + +func (s *String) Name() string { + return s.name +} +func (s *String) Type() Type { + return "STRING" +} +func (s *String) Row(stat *client.SessionDataSet, ptr bool) any { + if stat == nil { + if ptr { + return nil + } + return "" + } + getString, err := stat.GetString(s.name) + if err != nil { + if ptr { + return nil + } + return "" + } + if ptr { + return &getString + } + return getString +} + +var _ Interface = (*String)(nil) diff --git a/database/column/timestamp.go b/database/column/timestamp.go new file mode 100644 index 0000000..a79a0c2 --- /dev/null +++ b/database/column/timestamp.go @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package column + +import ( + "time" + + "github.com/apache/iotdb-client-go/v2/client" +) + +type Timestamp struct { + name string +} + +func (t *Timestamp) Name() string { + return t.name +} +func (t *Timestamp) Type() Type { + return "TIMESTAMP" +} + +func (t *Timestamp) Row(stat *client.SessionDataSet, ptr bool) any { + if stat == nil { + if ptr { + return nil + } + return time.Time{} + } + value, err := stat.GetTimestamp(t.name) + if err != nil { + if ptr { + return nil + } + return time.Time{} + } + if ptr { + return &value + } + return value +} diff --git a/database/conn.go b/database/conn.go new file mode 100644 index 0000000..46a0172 --- /dev/null +++ b/database/conn.go @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "fmt" + "log" + "net" + "os" + "time" + + "github.com/apache/iotdb-client-go/v2/client" + "github.com/apache/iotdb-client-go/v2/database/column" + "github.com/pkg/errors" +) + +func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, error) { + if addr == "" { + return nil, errors.New("empty addr") + } + // 使用 net.SplitHostPort 分割地址和端口 + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + var ( + conn client.SessionPool + debugf = func(format string, v ...any) {} + ) + + if opt.Debug { + if opt.Debugf != nil { + debugf = func(format string, v ...any) { + opt.Debugf( + "[iotdb][%s][id=%d] "+format, + append([]interface{}{opt.Addr, num}, v...)..., + ) + } + } else { + debugf = log.New(os.Stdout, fmt.Sprintf("[iotdb][%s][id=%d]", opt.Addr, num), 0).Printf + } + } + var ( + config = &client.PoolConfig{ + Host: host, + Port: port, + UserName: opt.UserName, + Password: opt.Password, + } + poolMaxSize = 3 + poolWaitToGetSessionTimeoutInMs = 60000 + poolConnectionTimeoutInMs = 60000 + poolEnableCompression = false + ) + if opt.PoolMaxSize != nil { + poolMaxSize = *opt.PoolMaxSize + } + if opt.PoolWaitToGetSessionTimeoutInMs != nil { + poolWaitToGetSessionTimeoutInMs = *opt.PoolWaitToGetSessionTimeoutInMs + } + if opt.PoolConnectionTimeoutInMs != nil { + poolConnectionTimeoutInMs = *opt.PoolConnectionTimeoutInMs + } + if opt.PoolEnableCompression != nil { + poolEnableCompression = *opt.PoolEnableCompression + } + conn = client.NewSessionPool(config, poolMaxSize, poolConnectionTimeoutInMs, poolWaitToGetSessionTimeoutInMs, poolEnableCompression) + + var ( + netConn = &connect{ + id: num, + opt: opt, + conn: conn, + debugfFunc: debugf, + connectedAt: time.Now(), + } + ) + + return netConn, nil +} + +type connect struct { + id int + opt *Options + conn client.SessionPool + debugfFunc func(format string, v ...any) + connectedAt time.Time + timeZone *time.Location +} + +func (c *connect) debugf(format string, v ...any) { + c.debugfFunc(format, v...) +} + +func (c *connect) isBad() bool { + return false +} +func (c *connect) close() error { + c.conn.Close() + return nil +} + +func (c *connect) ping(ctx context.Context) (err error) { + session, err := c.conn.GetSession() + if err != nil { + return err + } + defer c.conn.PutBack(session) + return session.Ping(ctx) +} + +func (c *connect) query(ctx context.Context, release nativeTransportRelease, query string, args ...any) (*rows, error) { + options := queryOptions(ctx) + body, err := bindQueryOrAppendParameters(&options, query, c.timeZone, args...) + if err != nil { + return nil, err + } + session, err := c.conn.GetSession() + if err != nil { + release(c, err) + return nil, err + } + defer c.conn.PutBack(session) + var timeout int64 = int64(c.opt.DialTimeout.Seconds() * 1000) + if timeout == 0 { + timeout = 5000 + } + statement, err := session.ExecuteQueryStatement(body, &timeout) + if err != nil { + release(c, err) + return nil, err + } + + // column list + names := statement.GetColumnNames() + columnsList := make([]column.Interface, len(names)) + for k, name := range names { + dataType := statement.GetColumnTypes()[k] + col := column.GenColumn(dataType, name) + if col == nil { + continue + } + columnsList[k] = col + } + return &rows{ + set: statement, + columns: columnsList, + }, nil +} + +func (c *connect) commit() error { + return nil +} diff --git a/database/conn_exec.go b/database/conn_exec.go new file mode 100644 index 0000000..9288086 --- /dev/null +++ b/database/conn_exec.go @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" +) + +func (c *connect) exec(ctx context.Context, query string, args ...any) error { + var ( + options = queryOptions(ctx) + body, err = bindQueryOrAppendParameters(&options, query, c.timeZone, args...) + ) + if err != nil { + return err + } + + session, err := c.conn.GetSession() + if err != nil { + return err + } + defer c.conn.PutBack(session) + + _, err = session.ExecuteStatement(body) + if err != nil { + return err + } + return nil +} diff --git a/database/conn_integration_test.go b/database/conn_integration_test.go new file mode 100644 index 0000000..70de961 --- /dev/null +++ b/database/conn_integration_test.go @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ==================== Connection Lifecycle ==================== + +// TestConnect_Ping verifies a basic ping to the server succeeds. +func TestConnect_Ping(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db, err := sql.Open("iotdb", "iotdb://root:root@127.0.0.1:6667") + require.NoError(t, err) + defer db.Close() + + err = db.PingContext(context.Background()) + require.NoError(t, err) +} + +// TestConnect_PingTimeout verifies ping succeeds within a context deadline. +func TestConnect_PingTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db, err := sql.Open("iotdb", "iotdb://root:root@127.0.0.1:6667") + require.NoError(t, err) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = db.PingContext(ctx) + require.NoError(t, err) +} + +// TestConnect_PingUnreachable verifies that ping to an unreachable host returns an error. +func TestConnect_PingUnreachable(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db, err := sql.Open("iotdb", "iotdb://root:root@127.0.0.1:19999") + require.NoError(t, err) + defer db.Close() + + err = db.PingContext(context.Background()) + require.Error(t, err) +} + +// ==================== Connection Pool ==================== + +// TestConnect_ConnectionPool verifies pool settings are applied and connections are created. +func TestConnect_ConnectionPool(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db, err := sql.Open("iotdb", "iotdb://root:root@127.0.0.1:6667") + require.NoError(t, err) + defer db.Close() + + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(10 * time.Minute) + + assert.Equal(t, 10, db.Stats().MaxOpenConnections) + + // Acquire a connection to force pool creation + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() + + assert.GreaterOrEqual(t, db.Stats().OpenConnections, 1) +} + +// ==================== Close ==================== + +// TestConnect_Close verifies that closing a db handle succeeds cleanly. +func TestConnect_Close(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db, err := sql.Open("iotdb", "iotdb://root:root@127.0.0.1:6667") + require.NoError(t, err) + + // Ping first to establish a real connection + err = db.PingContext(context.Background()) + require.NoError(t, err) + + err = db.Close() + require.NoError(t, err) +} + +// ==================== Exec + Query Round Trip ==================== + +// TestConnect_QueryBasic verifies a full write-then-read cycle through the driver. +func TestConnect_QueryBasic(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db, err := sql.Open("iotdb", "iotdb://root:root@127.0.0.1:6667") + require.NoError(t, err) + defer db.Close() + + sg := "root.conn_test" + _, err = db.ExecContext(context.Background(), "CREATE DATABASE "+sg) + require.NoError(t, err) + defer func() { _, _ = db.ExecContext(context.Background(), "DELETE DATABASE "+sg) }() + + _, err = db.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.value WITH DATATYPE=FLOAT, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = db.ExecContext(context.Background(), + "INSERT INTO "+sg+".d1(timestamp, value) VALUES (?, ?)", 1, 123.45) + require.NoError(t, err) + + rows, err := db.QueryContext(context.Background(), "SELECT value FROM "+sg+".d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var value float64 + err = rows.Scan(&ts, &value) + require.NoError(t, err) + assert.InDelta(t, 123.45, value, 0.01) + require.NoError(t, rows.Err()) +} + +// TestConnect_ShowDatabases verifies metadata queries work through a fresh connection. +func TestConnect_ShowDatabases(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db, err := sql.Open("iotdb", "iotdb://root:root@127.0.0.1:6667") + require.NoError(t, err) + defer db.Close() + + rows, err := db.QueryContext(context.Background(), "SHOW DATABASES") + require.NoError(t, err) + defer rows.Close() + + var count int + for rows.Next() { + count++ + } + assert.GreaterOrEqual(t, count, 1) + require.NoError(t, rows.Err()) +} diff --git a/database/context.go b/database/context.go new file mode 100644 index 0000000..32385e9 --- /dev/null +++ b/database/context.go @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "time" +) + +type Parameters map[string]string + +type Settings map[string]any + +// ColumnNameAndType represents a column name and type +type ColumnNameAndType struct { + Name string + Type string +} + +var _contextOptionKey = &QueryOptions{ + settings: Settings{ + "_contextOption": struct{}{}, + }, +} + +type ( + QueryOption func(*QueryOptions) error + AsyncOptions struct { + ok bool + wait bool + } + QueryOptions struct { + async AsyncOptions + queryID string + quotaKey string + jwt string + + settings Settings + parameters Parameters + blockBufferSize uint8 + userLocation *time.Location + columnNamesAndTypes []ColumnNameAndType + clientInfo ClientInfo + } +) + +// clone returns a copy of QueryOptions where Settings and Parameters are safely mutable. +func (q *QueryOptions) clone() QueryOptions { + c := QueryOptions{ + async: q.async, + queryID: q.queryID, + quotaKey: q.quotaKey, + settings: nil, + parameters: nil, + blockBufferSize: q.blockBufferSize, + userLocation: q.userLocation, + columnNamesAndTypes: nil, + } + + if q.settings != nil { + c.settings = make(Settings, len(q.settings)) + for k, v := range q.settings { + c.settings[k] = v + } + } + + if q.parameters != nil { + c.parameters = make(Parameters, len(q.parameters)) + for k, v := range q.parameters { + c.parameters[k] = v + } + } + + if q.columnNamesAndTypes != nil { + c.columnNamesAndTypes = make([]ColumnNameAndType, len(q.columnNamesAndTypes)) + copy(c.columnNamesAndTypes, q.columnNamesAndTypes) + } + + if q.clientInfo.Products != nil || q.clientInfo.Comment != nil { + c.clientInfo = q.clientInfo.Append(ClientInfo{}) + } + + return c +} + +// queryOptions returns a mutable copy of the QueryOptions struct within the given context. +// If iotdb context was not provided, an empty struct with a valid Settings map is returned. +// If the context has a deadline greater than 1s then max_execution_time setting is appended. +func queryOptions(ctx context.Context) QueryOptions { + var opt QueryOptions + + if ctxOpt, ok := ctx.Value(_contextOptionKey).(QueryOptions); ok { + opt = ctxOpt.clone() + } else { + opt = QueryOptions{ + settings: make(Settings), + } + } + + deadline, ok := ctx.Deadline() + if !ok { + return opt + } + + if sec := time.Until(deadline).Seconds(); sec > 1 { + opt.settings["max_execution_time"] = int(sec + 5) + } + + return opt +} diff --git a/database/driver/driver.go b/database/driver/driver.go new file mode 100644 index 0000000..2b41b27 --- /dev/null +++ b/database/driver/driver.go @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package driver + +import "time" + +type ( + NamedValue struct { + Name string + Value any + } + + NamedDateValue struct { + Name string + Value time.Time + Scale uint8 + } +) diff --git a/database/iotdb.go b/database/iotdb.go new file mode 100644 index 0000000..4c7bffd --- /dev/null +++ b/database/iotdb.go @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "errors" +) + +// nativeTransport represents an implementation (TCP or HTTP) that can be pooled by the main iotdb struct. +// Implementations are not expected to be thread safe, which is why we provide acquire/release functions. +type nativeTransport interface { + query(ctx context.Context, release nativeTransportRelease, query string, args ...any) (*rows, error) +} + +// nativeTransport represents an implementation (TCP or HTTP) that can be pooled by the main iotdb struct. +// Implementations are not expected to be thread safe, which is why we provide acquire/release functions. +type nativeTransportAcquire func(context.Context) (nativeTransport, error) +type nativeTransportRelease func(nativeTransport, error) + +var ( + ErrBatchInvalid = errors.New("iotdb: batch is invalid. check appended data is correct") + ErrBatchAlreadySent = errors.New("iotdb: batch has already been sent") + ErrBatchNotSent = errors.New("iotdb: invalid retry, batch not sent yet") + ErrAcquireConnTimeout = errors.New("iotdb: acquire conn timeout. you can increase the number of max open conn or the dial timeout") + ErrUnsupportedServerRevision = errors.New("iotdb: unsupported server revision") + ErrBindMixedParamsFormats = errors.New("iotdb [bind]: mixed named, numeric or positional parameters") + ErrAcquireConnNoAddress = errors.New("iotdb: no valid address supplied") + ErrServerUnexpectedData = errors.New("code: 101, message: Unexpected packet Data received from client") + ErrConnectionClosed = errors.New("iotdb: connection is closed") +) diff --git a/database/iotdb_std.go b/database/iotdb_std.go new file mode 100644 index 0000000..406685a --- /dev/null +++ b/database/iotdb_std.go @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "log" + "net" + "os" + "syscall" +) + +var globalConnID int64 + +func init() { + var debugf = func(format string, v ...any) {} + sql.Register("iotdb", &stdDriver{debugf: debugf}) +} + +// isConnBrokenError returns true if the error class indicates that the +// db connection is no longer usable and should be marked bad +func isConnBrokenError(err error) bool { + if errors.Is(err, io.EOF) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { + return true + } + if _, ok := err.(*net.OpError); ok { + return true + } + return false +} + +type stdDriver struct { + opt *Options + conn stdConnect + commit func() error + debugf func(format string, v ...any) +} + +var _ driver.Conn = (*stdDriver)(nil) +var _ driver.ConnBeginTx = (*stdDriver)(nil) +var _ driver.ExecerContext = (*stdDriver)(nil) +var _ driver.QueryerContext = (*stdDriver)(nil) +var _ driver.ConnPrepareContext = (*stdDriver)(nil) + +func (std *stdDriver) Open(dsn string) (_ driver.Conn, err error) { + var opt Options + if err := opt.fromDSN(dsn); err != nil { + std.debugf("Open dsn error: %v\n", err) + return nil, err + } + var debugf = func(format string, v ...any) {} + if opt.Debug { + debugf = log.New(os.Stdout, "[iotdb-std][opener] ", 0).Printf + } + opt.ClientInfo.Comment = []string{"database/sql"} + return (&stdConnOpener{opt: &opt, debugf: debugf}).Connect(context.Background()) +} + +var _ driver.Driver = (*stdDriver)(nil) + +func (std *stdDriver) ResetSession(ctx context.Context) error { + if std.conn.isBad() { + std.debugf("Resetting session because connection is bad") + return driver.ErrBadConn + } + return nil +} + +var _ driver.SessionResetter = (*stdDriver)(nil) + +func (std *stdDriver) Ping(ctx context.Context) error { + if std.conn.isBad() { + std.debugf("Ping: connection is bad") + return driver.ErrBadConn + } + + return std.conn.ping(ctx) +} + +var _ driver.Pinger = (*stdDriver)(nil) + +// Begin starts and returns a new transaction. +// +// Deprecated: Drivers should implement ConnBeginTx instead (or additionally). +func (std *stdDriver) Begin() (driver.Tx, error) { + if std.conn.isBad() { + std.debugf("Begin: connection is bad") + return nil, driver.ErrBadConn + } + return std, nil +} + +func (std *stdDriver) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if std.conn.isBad() { + std.debugf("BeginTx: connection is bad") + return nil, driver.ErrBadConn + } + + return std, nil +} + +func (std *stdDriver) Commit() error { + if std.commit == nil { + return nil + } + defer func() { + std.commit = nil + }() + + if err := std.commit(); err != nil { + if isConnBrokenError(err) { + std.debugf("Commit got EOF error: resetting connection") + return driver.ErrBadConn + } + std.debugf("Commit error: %v\n", err) + return err + } + return nil +} + +func (std *stdDriver) Rollback() error { + std.commit = nil + //std.conn.close() + return nil +} + +var _ driver.Tx = (*stdDriver)(nil) + +func (std *stdDriver) CheckNamedValue(nv *driver.NamedValue) error { return nil } + +var _ driver.NamedValueChecker = (*stdDriver)(nil) + +// Prepare returns a prepared statement, bound to this connection. +func (std *stdDriver) Prepare(query string) (driver.Stmt, error) { + return std.PrepareContext(context.Background(), query) +} + +func (std *stdDriver) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if std.conn.isBad() { + std.debugf("QueryContext: connection is bad") + return nil, driver.ErrBadConn + } + + r, err := std.conn.query(ctx, func(nativeTransport, error) {}, query, rebind(args)...) + if isConnBrokenError(err) { + std.debugf("QueryContext got a fatal error, resetting connection: %v\n", err) + return nil, driver.ErrBadConn + } + if err != nil { + std.debugf("QueryContext error: %v\n", err) + return nil, err + } + return &stdRows{ + rows: r, + debugf: std.debugf, + }, nil +} + +func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if std.conn.isBad() { + std.debugf("PrepareContext: connection is bad") + return nil, driver.ErrBadConn + } + + std.commit = std.conn.commit + return &stdBatch{ + debugf: std.debugf, + }, nil +} + +// Close invalidates and potentially stops any current +// prepared statements and transactions, marking this +// connection as no longer in use. +// +// Because the sql package maintains a free pool of +// connections and only calls Close when there's a surplus of +// idle connections, it shouldn't be necessary for drivers to +// do their own connection caching. +// +// Drivers must ensure all network calls made by Close +// do not block indefinitely (e.g. apply a timeout). +func (std *stdDriver) Close() error { + err := std.conn.close() + if err != nil { + if isConnBrokenError(err) { + std.debugf("Close got a fatal error, resetting connection: %v\n", err) + return driver.ErrBadConn + } + std.debugf("Close error: %v\n", err) + } + return err +} + +func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if std.conn.isBad() { + std.debugf("ExecContext: connection is bad") + return nil, driver.ErrBadConn + } + err := std.conn.exec(ctx, query, rebind(args)...) + if isConnBrokenError(err) { + std.debugf("ExecContext got a fatal error, resetting connection: %v\n", err) + return nil, driver.ErrBadConn + } + if err != nil { + std.debugf("ExecContext error: %v\n", err) + return nil, err + } + return driver.RowsAffected(0), nil +} diff --git a/database/iotdb_std_test.go b/database/iotdb_std_test.go new file mode 100644 index 0000000..050ddef --- /dev/null +++ b/database/iotdb_std_test.go @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var stdConn *sql.DB + +func init() { + var err error + stdConn, err = sql.Open("iotdb", "iotdb://root:root@127.0.0.1:6667") + if err != nil { + panic(err) + } +} + +// ==================== INSERT Operations ==================== + +// TestStd_BasicInsert verifies a single-row INSERT with mixed types. +func TestStd_BasicInsert(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + sg := "root.std_test" + _, err := stdConn.ExecContext(context.Background(), "CREATE DATABASE "+sg) + require.NoError(t, err) + defer func() { _, _ = stdConn.ExecContext(context.Background(), "DELETE DATABASE "+sg) }() + + _, err = stdConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.status WITH DATATYPE=BOOLEAN, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = stdConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.temp WITH DATATYPE=FLOAT, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = stdConn.ExecContext(context.Background(), + "INSERT INTO "+sg+".d1(timestamp, status, temp) VALUES (?, ?, ?)", + 1, true, 25.5) + require.NoError(t, err) + + // Verify inserted data can be read back + rows, err := stdConn.QueryContext(context.Background(), "SELECT status, temp FROM "+sg+".d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var status bool + var temp float64 + err = rows.Scan(&ts, &status, &temp) + require.NoError(t, err) + assert.True(t, status) + assert.InDelta(t, 25.5, temp, 0.1) + require.NoError(t, rows.Err()) +} + +// TestStd_MultiValueInsert verifies inserting multiple rows in a single statement. +func TestStd_MultiValueInsert(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + sg := "root.std_multi" + _, err := stdConn.ExecContext(context.Background(), "CREATE DATABASE "+sg) + require.NoError(t, err) + defer func() { _, _ = stdConn.ExecContext(context.Background(), "DELETE DATABASE "+sg) }() + + _, err = stdConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.status WITH DATATYPE=BOOLEAN, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = stdConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.temp WITH DATATYPE=FLOAT, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = stdConn.ExecContext(context.Background(), + "INSERT INTO "+sg+".d1(timestamp, status, temp) VALUES (?, ?, ?), (?, ?, ?)", + 1, true, 25.0, + 2, false, 26.0) + require.NoError(t, err) + + // Verify both rows exist + rows, err := stdConn.QueryContext(context.Background(), "SELECT temp FROM "+sg+".d1 ORDER BY time ASC") + require.NoError(t, err) + defer rows.Close() + + var count int + for rows.Next() { + count++ + } + assert.Equal(t, 2, count) + require.NoError(t, rows.Err()) +} + +// TestStd_InsertTypes verifies INSERT with INT64 and FLOAT types. +func TestStd_InsertTypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + sg := "root.std_types" + _, err := stdConn.ExecContext(context.Background(), "CREATE DATABASE "+sg) + require.NoError(t, err) + defer func() { _, _ = stdConn.ExecContext(context.Background(), "DELETE DATABASE "+sg) }() + + _, err = stdConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.int_val WITH DATATYPE=INT64, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = stdConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.float_val WITH DATATYPE=FLOAT, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = stdConn.ExecContext(context.Background(), + "INSERT INTO "+sg+".d1(timestamp, int_val, float_val) VALUES (?, ?, ?)", + 1, 1718193600000, 99.9) + require.NoError(t, err) + + // Verify values + row := stdConn.QueryRowContext(context.Background(), + "SELECT int_val, float_val FROM "+sg+".d1") + var ts int64 + var intVal int64 + var floatVal float64 + err = row.Scan(&ts, &intVal, &floatVal) + require.NoError(t, err) + assert.Equal(t, int64(1718193600000), intVal) + assert.InDelta(t, 99.9, floatVal, 0.1) +} + +// ==================== Basic Query ==================== + +// TestStd_QueryBasic verifies a round-trip INSERT then SELECT. +func TestStd_QueryBasic(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + sg := "root.std_query" + _, err := stdConn.ExecContext(context.Background(), "CREATE DATABASE "+sg) + require.NoError(t, err) + defer func() { _, _ = stdConn.ExecContext(context.Background(), "DELETE DATABASE "+sg) }() + + _, err = stdConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.temp WITH DATATYPE=FLOAT, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = stdConn.ExecContext(context.Background(), + "INSERT INTO "+sg+".d1(timestamp, temp) VALUES (?, ?)", + 1704067200000, 25.5) + require.NoError(t, err) + + rows, err := stdConn.QueryContext(context.Background(), "SELECT temp FROM "+sg+".d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var temp float64 + err = rows.Scan(&ts, &temp) + require.NoError(t, err) + assert.InDelta(t, 25.5, temp, 0.1) + require.NoError(t, rows.Err()) +} + +// ==================== Struct Scanning ==================== + +// TestStd_ScanStruct verifies scanning multiple columns into struct fields. +func TestStd_ScanStruct(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + sg := "root.std_scan" + _, err := stdConn.ExecContext(context.Background(), "CREATE DATABASE "+sg) + require.NoError(t, err) + defer func() { _, _ = stdConn.ExecContext(context.Background(), "DELETE DATABASE "+sg) }() + + _, err = stdConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.temp WITH DATATYPE=FLOAT, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = stdConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.hum WITH DATATYPE=DOUBLE, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = stdConn.ExecContext(context.Background(), + "INSERT INTO "+sg+".d1(timestamp, temp, hum) VALUES (?, ?, ?)", + 1704067200000, 25.5, 60.5) + require.NoError(t, err) + + rows, err := stdConn.QueryContext(context.Background(), "SELECT temp, hum FROM "+sg+".d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + + type Reading struct { + Time int64 + Temp float64 + Hum float64 + } + var r Reading + err = rows.Scan(&r.Time, &r.Temp, &r.Hum) + require.NoError(t, err) + assert.InDelta(t, 25.5, r.Temp, 0.1) + assert.InDelta(t, 60.5, r.Hum, 0.1) + require.NoError(t, rows.Err()) +} + +// ==================== Metadata Queries ==================== + +// TestStd_ShowDatabases verifies SHOW DATABASES returns at least one result. +func TestStd_ShowDatabases(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + rows, err := stdConn.QueryContext(context.Background(), "SHOW DATABASES") + require.NoError(t, err) + defer rows.Close() + + var count int + for rows.Next() { + count++ + } + assert.GreaterOrEqual(t, count, 1) + require.NoError(t, rows.Err()) +} diff --git a/database/options.go b/database/options.go new file mode 100644 index 0000000..973d10f --- /dev/null +++ b/database/options.go @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "crypto/tls" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/apache/iotdb-client-go/v2/client" + "github.com/pkg/errors" +) + +type ClientInfo struct { + Products []struct { + Name string + Version string + } + + Comment []string +} + +// Append returns a new copy of the combined ClientInfo structs +func (a ClientInfo) Append(b ClientInfo) ClientInfo { + c := ClientInfo{ + Products: make([]struct { + Name string + Version string + }, 0, len(a.Products)+len(b.Products)), + Comment: make([]string, 0, len(a.Comment)+len(b.Comment)), + } + + for _, p := range a.Products { + c.Products = append(c.Products, p) + } + for _, p := range b.Products { + c.Products = append(c.Products, p) + } + + for _, cm := range a.Comment { + c.Comment = append(c.Comment, cm) + } + for _, cm := range b.Comment { + c.Comment = append(c.Comment, cm) + } + + return c +} + +type ConnOpenStrategy uint8 + +const ( + ConnOpenInOrder ConnOpenStrategy = iota + ConnOpenRoundRobin + ConnOpenRandom +) + +type Options struct { + client.PoolConfig + Debug bool + ClientInfo ClientInfo + Addr []string + ConnOpenStrategy ConnOpenStrategy + DialContext func(ctx context.Context, addr string) (net.Conn, error) + TLS *tls.Config + DialTimeout time.Duration // default 30 second + Debugf func(format string, v ...any) // only works when Debug is true + PoolMaxSize *int + PoolConnectionTimeoutInMs *int + PoolWaitToGetSessionTimeoutInMs *int + PoolEnableCompression *bool +} + +func ParseDSN(dsn string) (*Options, error) { + opt := &Options{} + err := opt.fromDSN(dsn) + if err != nil { + return nil, err + } + return opt, nil +} + +func (o *Options) fromDSN(in string) error { + dsn, err := url.Parse(in) + if err != nil { + return err + } + if dsn.Host == "" { + return errors.New("parse dsn address failed") + } + + if dsn.User != nil { + o.PoolConfig.UserName = dsn.User.Username() + o.PoolConfig.Password, _ = dsn.User.Password() + } + + //o.PoolConfig.Password = strings.TrimPrefix(dsn.Path, "/") + o.Addr = append(o.Addr, strings.Split(dsn.Host, ",")...) + + var params = dsn.Query() + + for v := range params { + switch v { + case "fetch_size": + { + fetchSize, err := strconv.ParseInt(params.Get(v), 10, 32) + if err != nil { + return errors.Wrap(err, "fetch size invalid value") + } + o.PoolConfig.FetchSize = int32(fetchSize) + break + } + case "time_zone": + { + o.PoolConfig.TimeZone = params.Get(v) + break + } + case "connect_retry_max": + { + connectRetryMax, err := strconv.ParseInt(params.Get(v), 10, 32) + if err != nil { + return errors.Wrap(err, "connect retry max invalid value") + } + o.PoolConfig.ConnectRetryMax = int(connectRetryMax) + break + } + case "username": + { + o.PoolConfig.UserName = params.Get(v) + break + } + case "password": + { + o.PoolConfig.Password = params.Get(v) + break + } + default: + { + break + } + } + } + return nil +} diff --git a/database/options_test.go b/database/options_test.go new file mode 100644 index 0000000..bb58da9 --- /dev/null +++ b/database/options_test.go @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ==================== Basic DSN Parsing ==================== + +// TestParseDSN_Basic verifies minimal DSN with host and port. +func TestParseDSN_Basic(t *testing.T) { + opt, err := ParseDSN("iotdb://root:root@127.0.0.1:6667") + require.NoError(t, err) + assert.Equal(t, []string{"127.0.0.1:6667"}, opt.Addr) + assert.Equal(t, "root", opt.PoolConfig.UserName) + assert.Equal(t, "root", opt.PoolConfig.Password) +} + +// TestParseDSN_WithQueryParams verifies DSN parameters are parsed correctly. +func TestParseDSN_WithQueryParams(t *testing.T) { + opt, err := ParseDSN("iotdb://user:pass@localhost:6667?fetch_size=1024&time_zone=UTC&connect_retry_max=5") + require.NoError(t, err) + assert.Equal(t, "user", opt.PoolConfig.UserName) + assert.Equal(t, "pass", opt.PoolConfig.Password) + assert.Equal(t, int32(1024), opt.PoolConfig.FetchSize) + assert.Equal(t, "UTC", opt.PoolConfig.TimeZone) + assert.Equal(t, 5, opt.PoolConfig.ConnectRetryMax) +} + +// TestParseDSN_MultipleAddresses verifies comma-separated addresses are split correctly. +func TestParseDSN_MultipleAddresses(t *testing.T) { + opt, err := ParseDSN("iotdb://root:root@host1:6667,host2:6667,host3:6667") + require.NoError(t, err) + assert.Equal(t, []string{"host1:6667", "host2:6667", "host3:6667"}, opt.Addr) +} + +// TestParseDSN_UsernamePasswordOverride verifies query params override URL credentials. +func TestParseDSN_UsernamePasswordOverride(t *testing.T) { + opt, err := ParseDSN("iotdb://old:old@localhost:6667?username=new_user&password=new_pass") + require.NoError(t, err) + assert.Equal(t, "new_user", opt.PoolConfig.UserName) + assert.Equal(t, "new_pass", opt.PoolConfig.Password) +} + +// TestParseDSN_NoCredentials verifies DSN without user info is valid. +func TestParseDSN_NoCredentials(t *testing.T) { + opt, err := ParseDSN("iotdb://127.0.0.1:6667") + require.NoError(t, err) + assert.Equal(t, []string{"127.0.0.1:6667"}, opt.Addr) + assert.Equal(t, "", opt.PoolConfig.UserName) + assert.Equal(t, "", opt.PoolConfig.Password) +} + +// ==================== DSN Error Cases ==================== + +// TestParseDSN_EmptyHost returns an error for DSN without host. +func TestParseDSN_EmptyHost(t *testing.T) { + _, err := ParseDSN("iotdb://") + assert.Error(t, err) +} + +// TestParseDSN_InvalidFetchSize returns an error for non-integer fetch_size. +func TestParseDSN_InvalidFetchSize(t *testing.T) { + _, err := ParseDSN("iotdb://root:root@localhost:6667?fetch_size=abc") + assert.Error(t, err) +} + +// TestParseDSN_InvalidConnectRetryMax returns an error for non-integer connect_retry_max. +func TestParseDSN_InvalidConnectRetryMax(t *testing.T) { + _, err := ParseDSN("iotdb://root:root@localhost:6667?connect_retry_max=xyz") + assert.Error(t, err) +} + +// TestParseDSN_InvalidURL returns an error for malformed URLs. +func TestParseDSN_InvalidURL(t *testing.T) { + _, err := ParseDSN("://not-a-url") + assert.Error(t, err) +} + +// ==================== Unknown params are silently ignored ==================== + +// TestParseDSN_UnknownParams verifies unknown query params don't cause errors. +func TestParseDSN_UnknownParams(t *testing.T) { + opt, err := ParseDSN("iotdb://root:root@localhost:6667?unknown_param=value") + require.NoError(t, err) + assert.Equal(t, "root", opt.PoolConfig.UserName) +} diff --git a/database/query_agg_test.go b/database/query_agg_test.go new file mode 100644 index 0000000..2467057 --- /dev/null +++ b/database/query_agg_test.go @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ==================== Aggregation Functions ==================== + +// TestQuery_AggCount verifies COUNT returns the correct row count. +func TestQuery_AggCount(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.agg_count") + defer cleanup() + + row := queryConn.QueryRowContext(context.Background(), + "SELECT COUNT(temp) FROM root.agg_count.d1") + var count int64 + err := row.Scan(&count) + require.NoError(t, err) + assert.Equal(t, int64(5), count) +} + +// TestQuery_AggSum verifies SUM aggregation. +func TestQuery_AggSum(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.agg_sum") + defer cleanup() + + row := queryConn.QueryRowContext(context.Background(), + "SELECT SUM(temp) FROM root.agg_sum.d1") + var sum float64 + err := row.Scan(&sum) + require.NoError(t, err) + assert.InDelta(t, 110.0, sum, 0.1) +} + +// TestQuery_AggAvg verifies AVG aggregation. +func TestQuery_AggAvg(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.agg_avg") + defer cleanup() + + row := queryConn.QueryRowContext(context.Background(), + "SELECT AVG(temp) FROM root.agg_avg.d1") + var avg float64 + err := row.Scan(&avg) + require.NoError(t, err) + assert.InDelta(t, 22.0, avg, 0.1) +} + +// TestQuery_AggMinMax verifies MIN_VALUE and MAX_VALUE aggregation. +func TestQuery_AggMinMax(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.agg_minmax") + defer cleanup() + + row := queryConn.QueryRowContext(context.Background(), + "SELECT MIN_VALUE(temp), MAX_VALUE(temp) FROM root.agg_minmax.d1") + var minTemp, maxTemp float64 + err := row.Scan(&minTemp, &maxTemp) + require.NoError(t, err) + assert.InDelta(t, 20.0, minTemp, 0.1) + assert.InDelta(t, 24.0, maxTemp, 0.1) +} + +// TestQuery_AggFirstLastValue verifies FIRST_VALUE and LAST_VALUE aggregation. +func TestQuery_AggFirstLastValue(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.agg_firstlast") + defer cleanup() + + row := queryConn.QueryRowContext(context.Background(), + "SELECT FIRST_VALUE(temp), LAST_VALUE(temp) FROM root.agg_firstlast.d1") + var first, last float64 + err := row.Scan(&first, &last) + require.NoError(t, err) + assert.InDelta(t, 20.0, first, 0.1) + assert.InDelta(t, 24.0, last, 0.1) +} + +// TestQuery_AggMinMaxTime verifies MIN_TIME and MAX_TIME return results. +func TestQuery_AggMinMaxTime(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.agg_mintime") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT MIN_TIME(temp), MAX_TIME(temp) FROM root.agg_mintime.d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + require.NoError(t, rows.Err()) +} + +// TestQuery_MultipleAggregations verifies multiple aggregation functions in one query. +func TestQuery_MultipleAggregations(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.agg_multi") + defer cleanup() + + row := queryConn.QueryRowContext(context.Background(), + "SELECT COUNT(temp), SUM(temp), AVG(temp), MIN_VALUE(temp), MAX_VALUE(temp) FROM root.agg_multi.d1") + var count int64 + var sum, avg, min, max float64 + err := row.Scan(&count, &sum, &avg, &min, &max) + require.NoError(t, err) + assert.Equal(t, int64(5), count) + assert.InDelta(t, 110.0, sum, 0.1) + assert.InDelta(t, 22.0, avg, 0.1) + assert.InDelta(t, 20.0, min, 0.1) + assert.InDelta(t, 24.0, max, 0.1) +} + +// ==================== GROUP BY ==================== + +// TestQuery_GroupByTimeInterval verifies time-interval grouping. +func TestQuery_GroupByTimeInterval(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.group_time") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT COUNT(temp), AVG(temp) FROM root.group_time.d1 GROUP BY ([1704067200000, 1704067500000), 2m)") + require.NoError(t, err) + defer rows.Close() + + assert.Greater(t, countRows(t, rows), 0) +} + +// TestQuery_GroupByWithWhere verifies GROUP BY combined with WHERE. +func TestQuery_GroupByWithWhere(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.group_where") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT COUNT(temp) FROM root.group_where.d1 WHERE time < 1704067440000 GROUP BY ([1704067200000, 1704067500000), 2m)") + require.NoError(t, err) + defer rows.Close() + + assert.Greater(t, countRows(t, rows), 0) +} + +// ==================== ORDER BY ==================== + +// TestQuery_OrderByTimeDesc verifies descending time ordering. +func TestQuery_OrderByTimeDesc(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.order_time") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.order_time.d1 ORDER BY time DESC") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 5, countRows(t, rows)) +} + +// TestQuery_OrderByTimeAsc verifies ascending time ordering. +func TestQuery_OrderByTimeAsc(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.order_asc") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.order_asc.d1 ORDER BY time ASC") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 5, countRows(t, rows)) +} + +// ==================== LIMIT and OFFSET ==================== + +// TestQuery_Limit verifies LIMIT restricts row count. +func TestQuery_Limit(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.limit_test") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT * FROM root.limit_test.d1 LIMIT 3") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 3, countRows(t, rows)) +} + +// TestQuery_LimitOffset verifies LIMIT with OFFSET skips rows correctly. +func TestQuery_LimitOffset(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.limit_offset") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.limit_offset.d1 ORDER BY time ASC LIMIT 2 OFFSET 2") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 2, countRows(t, rows)) +} + +// ==================== ALIGN BY ==================== + +// TestQuery_AlignByDevice verifies cross-device alignment. +func TestQuery_AlignByDevice(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.align_device") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.align_device.d1, root.align_device.d2 ALIGN BY DEVICE") + require.NoError(t, err) + defer rows.Close() + + assert.GreaterOrEqual(t, countRows(t, rows), 1) +} + +// TestQuery_AlignByTime verifies time-based alignment. +func TestQuery_AlignByTime(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.align_time") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT * FROM root.align_time.d1 ALIGN BY TIME") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 5, countRows(t, rows)) +} + +// ==================== LAST Query ==================== + +// TestQuery_LastSingleColumn verifies LAST returns one row per column. +func TestQuery_LastSingleColumn(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.last_single") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT LAST temp FROM root.last_single.d1") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 1, countRows(t, rows)) +} + +// TestQuery_LastMultipleColumns verifies LAST for multiple columns. +func TestQuery_LastMultipleColumns(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.last_multi") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT LAST temp, hum FROM root.last_multi.d1") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 2, countRows(t, rows)) +} + +// TestQuery_LastAllColumns verifies LAST * returns all columns. +func TestQuery_LastAllColumns(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.last_all") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT LAST * FROM root.last_all.d1") + require.NoError(t, err) + defer rows.Close() + + assert.GreaterOrEqual(t, countRows(t, rows), 3) +} + +// TestQuery_LastWithWhere verifies LAST with time filter. +func TestQuery_LastWithWhere(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.last_where") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT LAST temp FROM root.last_where.d1 WHERE time >= 1704067320000") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 1, countRows(t, rows)) +} + +// ==================== Expressions ==================== + +// TestQuery_ArithmeticExpression verifies arithmetic in SELECT. +func TestQuery_ArithmeticExpression(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.expr_arith") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp + 10, temp * 2, temp / 2 FROM root.expr_arith.d1") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 5, countRows(t, rows)) +} + +// TestQuery_AggregationExpression verifies arithmetic on aggregation results. +func TestQuery_AggregationExpression(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.expr_agg") + defer cleanup() + + row := queryConn.QueryRowContext(context.Background(), + "SELECT AVG(temp) + 1, SUM(temp) * 2 FROM root.expr_agg.d1") + var avgPlus, sumTimes float64 + err := row.Scan(&avgPlus, &sumTimes) + require.NoError(t, err) + assert.InDelta(t, 23.0, avgPlus, 0.1) + assert.InDelta(t, 220.0, sumTimes, 0.1) +} + +// ==================== Metadata ==================== + +// TestQuery_ShowDatabases verifies SHOW DATABASES returns results. +func TestQuery_ShowDatabases(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + rows, err := queryConn.QueryContext(context.Background(), "SHOW DATABASES") + require.NoError(t, err) + defer rows.Close() + + assert.GreaterOrEqual(t, countRows(t, rows), 1) +} + +// TestQuery_ShowTimeseries verifies SHOW TIMESERIES returns correct count. +func TestQuery_ShowTimeseries(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.show_ts") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SHOW TIMESERIES root.show_ts.**") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 4, countRows(t, rows)) +} + +// TestQuery_ColumnsMetadata verifies column names are retrievable from result set. +func TestQuery_ColumnsMetadata(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.col_meta") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT * FROM root.col_meta.d1") + require.NoError(t, err) + defer rows.Close() + + cols, err := rows.Columns() + require.NoError(t, err) + assert.GreaterOrEqual(t, len(cols), 3) +} diff --git a/database/query_parameters.go b/database/query_parameters.go new file mode 100644 index 0000000..a4b55d9 --- /dev/null +++ b/database/query_parameters.go @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "errors" + "regexp" + "time" +) + +var ( + ErrInvalidValueInNamedDateValue = errors.New("invalid value in NamedDateValue for query parameter") + ErrUnsupportedQueryParameter = errors.New("unsupported query parameter type") + + hasQueryParamsRe = regexp.MustCompile("{.+:.+}") +) + +func bindQueryOrAppendParameters(options *QueryOptions, query string, timezone *time.Location, args ...any) (string, error) { + // prefer native query parameters over legacy bind if query parameters provided explicit + if len(options.parameters) > 0 { + return query, nil + } + + return bind(timezone, query, args...) +} diff --git a/database/query_test.go b/database/query_test.go new file mode 100644 index 0000000..f30cff8 --- /dev/null +++ b/database/query_test.go @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var queryConn *sql.DB + +func init() { + var err error + queryConn, err = sql.Open("iotdb", "iotdb://root:root@127.0.0.1:6667") + if err != nil { + panic(err) + } +} + +// setupQueryTestData creates a storage group with two devices and sample data. +// Returns a cleanup function that deletes the storage group. +func setupQueryTestData(t *testing.T, db *sql.DB, sg string) func() { + _, err := db.ExecContext(context.Background(), "CREATE DATABASE "+sg) + require.NoError(t, err) + + _, err = db.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.temp WITH DATATYPE=FLOAT, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = db.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.hum WITH DATATYPE=DOUBLE, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = db.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.status WITH DATATYPE=BOOLEAN, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = db.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d2.temp WITH DATATYPE=FLOAT, ENCODING=PLAIN") + require.NoError(t, err) + + _, err = db.ExecContext(context.Background(), + "INSERT INTO "+sg+".d1(timestamp, temp, hum, status) VALUES (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?)", + 1704067200000, 20.0, 60.0, true, + 1704067260000, 21.0, 61.0, false, + 1704067320000, 22.0, 62.0, true, + 1704067380000, 23.0, 63.0, false, + 1704067440000, 24.0, 64.0, true) + require.NoError(t, err) + + _, err = db.ExecContext(context.Background(), + "INSERT INTO "+sg+".d2(timestamp, temp) VALUES (?, ?), (?, ?), (?, ?)", + 1704067200000, 30.0, + 1704067260000, 31.0, + 1704067320000, 32.0) + require.NoError(t, err) + + return func() { + _, _ = db.ExecContext(context.Background(), "DELETE DATABASE "+sg) + } +} + +// countRows iterates all rows and returns the count, asserting no iteration error. +func countRows(t *testing.T, rows *sql.Rows) int { + var count int + for rows.Next() { + count++ + } + require.NoError(t, rows.Err()) + return count +} + +// ==================== SELECT Basics ==================== + +// TestQuery_SelectSingleColumn verifies selecting one column returns all rows. +func TestQuery_SelectSingleColumn(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.select_single") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.select_single.d1") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 5, countRows(t, rows)) +} + +// TestQuery_SelectMultipleColumns verifies column metadata for multi-column select. +func TestQuery_SelectMultipleColumns(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.select_multi") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp, hum, status FROM root.select_multi.d1") + require.NoError(t, err) + defer rows.Close() + + cols, err := rows.Columns() + require.NoError(t, err) + assert.GreaterOrEqual(t, len(cols), 3) +} + +// TestQuery_SelectAll verifies SELECT * returns all rows. +func TestQuery_SelectAll(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.select_all") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT * FROM root.select_all.d1") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 5, countRows(t, rows)) +} + +// TestQuery_SelectWithAlias verifies column aliasing works. +func TestQuery_SelectWithAlias(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.select_alias") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp AS temperature, hum AS humidity FROM root.select_alias.d1") + require.NoError(t, err) + defer rows.Close() + + cols, err := rows.Columns() + require.NoError(t, err) + assert.GreaterOrEqual(t, len(cols), 2) +} + +// ==================== WHERE Clause ==================== + +// TestQuery_WhereTimeFilter verifies time-based filtering. +func TestQuery_WhereTimeFilter(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.where_time") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.where_time.d1 WHERE time >= 1704067260000") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 4, countRows(t, rows)) +} + +// TestQuery_WhereTimeRange verifies time range filtering with AND. +func TestQuery_WhereTimeRange(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.where_range") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.where_range.d1 WHERE time >= 1704067260000 AND time < 1704067380000") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 2, countRows(t, rows)) +} + +// TestQuery_WhereValueFilter verifies value-based filtering. +func TestQuery_WhereValueFilter(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.where_value") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.where_value.d1 WHERE temp > 22.0") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 2, countRows(t, rows)) +} + +// TestQuery_WhereBooleanFilter verifies boolean value filtering. +func TestQuery_WhereBooleanFilter(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.where_bool") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.where_bool.d1 WHERE status = true") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 3, countRows(t, rows)) +} + +// TestQuery_WhereCombinedFilter verifies compound WHERE with time, value, and boolean. +func TestQuery_WhereCombinedFilter(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupQueryTestData(t, queryConn, "root.where_combined") + defer cleanup() + + rows, err := queryConn.QueryContext(context.Background(), + "SELECT temp FROM root.where_combined.d1 WHERE time > 1704067260000 AND temp > 22.0 AND status = true") + require.NoError(t, err) + defer rows.Close() + + assert.Equal(t, 1, countRows(t, rows)) +} diff --git a/database/rows.go b/database/rows.go new file mode 100644 index 0000000..dfe29dc --- /dev/null +++ b/database/rows.go @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "errors" + + "github.com/apache/iotdb-client-go/v2/client" + "github.com/apache/iotdb-client-go/v2/database/column" +) + +type rows struct { + set *client.SessionDataSet + columns []column.Interface +} + +func (r *rows) Next() (bool, error) { + if r.set == nil { + return false, errors.New("rows is nil") + } + + return r.set.Next() +} diff --git a/database/rows_std.go b/database/rows_std.go new file mode 100644 index 0000000..5f0ef69 --- /dev/null +++ b/database/rows_std.go @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "database/sql/driver" + "io" + + "github.com/pkg/errors" +) + +type stdRows struct { + rows *rows + debugf func(format string, v ...any) +} + +// Columns returns the names of the columns. The number of +// columns of the result is inferred from the length of the +// slice. If a particular column name isn't known, an empty +// string should be returned for that entry. +func (s *stdRows) Columns() []string { + return s.rows.set.GetColumnNames() +} + +// Close closes the rows iterator. +func (s *stdRows) Close() error { + s.rows.set.Close() + return nil +} + +// Next is called to populate the next row of data into +// the provided slice. The provided slice will be the same +// size as the Columns() are wide. +// +// Next should return io.EOF when there are no more rows. +// +// The dest should not be written to outside of Next. Care +// should be taken when closing Rows not to modify +// a buffer held in dest. +func (s *stdRows) Next(dest []driver.Value) error { + if len(s.rows.set.GetColumnNames()) != len(dest) { + return errors.New("column count mismatch") + } + next, err := s.rows.Next() + if err != nil { + s.debugf("rows.Next() failed: %v", err) + return err + } + if next { + for i := range dest { + if s.rows.columns[i] == nil { + dest[i] = nil + continue + } + switch value := s.rows.columns[i].Row(s.rows.set, false).(type) { + case driver.Valuer: + v, err := value.Value() + if err != nil { + s.debugf("Next row error: %v\n", err) + return err + } + dest[i] = v + default: + dest[i] = value + break + } + } + } else { + return io.EOF + } + return nil +} diff --git a/database/scan_types_test.go b/database/scan_types_test.go new file mode 100644 index 0000000..3a17e3c --- /dev/null +++ b/database/scan_types_test.go @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var scanConn *sql.DB + +func init() { + var err error + scanConn, err = sql.Open("iotdb", "iotdb://root:root@127.0.0.1:6667") + if err != nil { + panic(err) + } +} + +// setupScanTestData creates a database with all supported data types and inserts one row. +func setupScanTestData(t *testing.T, db *sql.DB, sg string) func() { + _, err := db.ExecContext(context.Background(), "CREATE DATABASE "+sg) + require.NoError(t, err) + + timeseries := []struct { + name string + dataType string + }{ + {"int32_val", "INT32"}, + {"int64_val", "INT64"}, + {"float_val", "FLOAT"}, + {"double_val", "DOUBLE"}, + {"bool_val", "BOOLEAN"}, + {"text_val", "TEXT"}, + } + for _, ts := range timeseries { + _, err = db.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1."+ts.name+" WITH DATATYPE="+ts.dataType+", ENCODING=PLAIN") + require.NoError(t, err) + } + + _, err = db.ExecContext(context.Background(), + "INSERT INTO "+sg+".d1(timestamp, int32_val, int64_val, float_val, double_val, bool_val, text_val) "+ + "VALUES (?, ?, ?, ?, ?, ?, ?)", + 1704067200000, 42, 1718193600000, 3.14, 2.718281828, true, "hello world") + require.NoError(t, err) + + return func() { + _, _ = db.ExecContext(context.Background(), "DELETE DATABASE "+sg) + } +} + +// ==================== Individual Type Scan Tests ==================== + +// TestScan_Int32 verifies INT32 column scans correctly as int32. +func TestScan_Int32(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupScanTestData(t, scanConn, "root.scan_int32") + defer cleanup() + + rows, err := scanConn.QueryContext(context.Background(), + "SELECT int32_val FROM root.scan_int32.d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var val int32 + err = rows.Scan(&ts, &val) + require.NoError(t, err) + assert.Equal(t, int32(42), val) + require.NoError(t, rows.Err()) +} + +// TestScan_Int64 verifies INT64 column handles large values without truncation. +func TestScan_Int64(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupScanTestData(t, scanConn, "root.scan_int64") + defer cleanup() + + rows, err := scanConn.QueryContext(context.Background(), + "SELECT int64_val FROM root.scan_int64.d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var val int64 + err = rows.Scan(&ts, &val) + require.NoError(t, err) + assert.Equal(t, int64(1718193600000), val) + require.NoError(t, rows.Err()) +} + +// TestScan_Float verifies FLOAT column scans correctly. +func TestScan_Float(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupScanTestData(t, scanConn, "root.scan_float") + defer cleanup() + + rows, err := scanConn.QueryContext(context.Background(), + "SELECT float_val FROM root.scan_float.d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var val float64 + err = rows.Scan(&ts, &val) + require.NoError(t, err) + assert.InDelta(t, 3.14, val, 0.01) + require.NoError(t, rows.Err()) +} + +// TestScan_Double verifies DOUBLE column scans with full precision. +func TestScan_Double(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupScanTestData(t, scanConn, "root.scan_double") + defer cleanup() + + rows, err := scanConn.QueryContext(context.Background(), + "SELECT double_val FROM root.scan_double.d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var val float64 + err = rows.Scan(&ts, &val) + require.NoError(t, err) + assert.InDelta(t, 2.718281828, val, 0.000001) + require.NoError(t, rows.Err()) +} + +// TestScan_Boolean verifies BOOLEAN column scans as bool. +func TestScan_Boolean(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupScanTestData(t, scanConn, "root.scan_bool") + defer cleanup() + + rows, err := scanConn.QueryContext(context.Background(), + "SELECT bool_val FROM root.scan_bool.d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var val bool + err = rows.Scan(&ts, &val) + require.NoError(t, err) + assert.True(t, val) + require.NoError(t, rows.Err()) +} + +// TestScan_Text verifies TEXT/STRING column scans as string. +func TestScan_Text(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupScanTestData(t, scanConn, "root.scan_text") + defer cleanup() + + rows, err := scanConn.QueryContext(context.Background(), + "SELECT text_val FROM root.scan_text.d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var val string + err = rows.Scan(&ts, &val) + require.NoError(t, err) + assert.Equal(t, "hello world", val) + require.NoError(t, rows.Err()) +} + +// ==================== Special String Values ==================== + +// TestScan_TextSpecialChars verifies TEXT handles quotes, backslashes, and unicode. +func TestScan_TextSpecialChars(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + sg := "root.scan_text_special" + _, err := scanConn.ExecContext(context.Background(), "CREATE DATABASE "+sg) + require.NoError(t, err) + defer func() { _, _ = scanConn.ExecContext(context.Background(), "DELETE DATABASE "+sg) }() + + _, err = scanConn.ExecContext(context.Background(), + "CREATE TIMESERIES "+sg+".d1.msg WITH DATATYPE=TEXT, ENCODING=PLAIN") + require.NoError(t, err) + + testCases := []struct { + name string + ts int + value string + }{ + {"chinese", 1, "你好世界"}, + {"emoji", 2, "hello 🌍"}, + } + + for _, tc := range testCases { + _, err = scanConn.ExecContext(context.Background(), + "INSERT INTO "+sg+".d1(timestamp, msg) VALUES (?, ?)", tc.ts, tc.value) + require.NoError(t, err, "insert failed for %s", tc.name) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rows, err := scanConn.QueryContext(context.Background(), + "SELECT msg FROM "+sg+".d1 WHERE time = ?", tc.ts) + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ts int64 + var val string + err = rows.Scan(&ts, &val) + require.NoError(t, err) + assert.Equal(t, tc.value, val) + require.NoError(t, rows.Err()) + }) + } +} + +// ==================== Empty Result Set ==================== + +// TestScan_EmptyResult verifies rows.Next() returns false for empty result. +func TestScan_EmptyResult(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupScanTestData(t, scanConn, "root.scan_empty") + defer cleanup() + + rows, err := scanConn.QueryContext(context.Background(), + "SELECT int32_val FROM root.scan_empty.d1 WHERE time = 999999999999999") + require.NoError(t, err) + defer rows.Close() + + assert.False(t, rows.Next()) + require.NoError(t, rows.Err()) +} + +// ==================== Multiple Columns In Single Scan ==================== + +// TestScan_AllTypesInOneRow verifies scanning all supported types in a single row. +func TestScan_AllTypesInOneRow(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + cleanup := setupScanTestData(t, scanConn, "root.scan_all") + defer cleanup() + + rows, err := scanConn.QueryContext(context.Background(), + "SELECT int32_val, int64_val, float_val, double_val, bool_val, text_val FROM root.scan_all.d1") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var ( + ts int64 + int32Val int32 + int64Val int64 + floatVal float64 + doubleVal float64 + boolVal bool + textVal string + ) + err = rows.Scan(&ts, &int32Val, &int64Val, &floatVal, &doubleVal, &boolVal, &textVal) + require.NoError(t, err) + + assert.Equal(t, int32(42), int32Val) + assert.Equal(t, int64(1718193600000), int64Val) + assert.InDelta(t, 3.14, floatVal, 0.01) + assert.InDelta(t, 2.718281828, doubleVal, 0.000001) + assert.True(t, boolVal) + assert.Equal(t, "hello world", textVal) + require.NoError(t, rows.Err()) +} diff --git a/database/std_conn_opener.go b/database/std_conn_opener.go new file mode 100644 index 0000000..3ac8972 --- /dev/null +++ b/database/std_conn_opener.go @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import ( + "context" + "database/sql/driver" + "fmt" + "log" + "math/rand" + "os" + "sync/atomic" +) + +type stdConnOpener struct { + err error + opt *Options + debugf func(format string, v ...any) +} + +func (o *stdConnOpener) Driver() driver.Driver { + var debugf = func(format string, v ...any) {} + if o.opt.Debug { + if o.opt.Debugf != nil { + debugf = o.opt.Debugf + } else { + debugf = log.New(os.Stdout, "[iotdb-std] ", 0).Printf + } + } + return &stdDriver{ + opt: o.opt, + debugf: debugf, + } +} + +func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error) { + if o.err != nil { + o.debugf("[connect] opener error: %v\n", o.err) + return nil, o.err + } + var ( + conn stdConnect + connID = int(atomic.AddInt64(&globalConnID, 1)) + dialFunc func(ctx context.Context, addr string, num int, opt *Options) (stdConnect, error) + ) + + dialFunc = func(ctx context.Context, addr string, num int, opt *Options) (stdConnect, error) { + return dial(ctx, addr, num, opt) + } + + if o.opt.Addr == nil || len(o.opt.Addr) == 0 { + return nil, ErrAcquireConnNoAddress + } + + for i := range o.opt.Addr { + var num int + switch o.opt.ConnOpenStrategy { + case ConnOpenInOrder: + num = i + case ConnOpenRoundRobin: + num = (connID + i) % len(o.opt.Addr) + case ConnOpenRandom: + random := rand.Int() + num = (random + i) % len(o.opt.Addr) + } + if conn, err = dialFunc(ctx, o.opt.Addr[num], connID, o.opt); err == nil { + var debugf = func(format string, v ...any) {} + if o.opt.Debug { + if o.opt.Debugf != nil { + debugf = o.opt.Debugf + } else { + debugf = log.New(os.Stdout, fmt.Sprintf("[iotdb-std][conn=%d][%s] ", num, o.opt.Addr[num]), 0).Printf + } + } + return &stdDriver{ + opt: o.opt, + conn: conn, + debugf: debugf, + }, nil + } else { + o.debugf("[connect] error connecting to %s on connection %d: %v\n", o.opt.Addr[num], connID, err) + } + } + + return nil, err +} diff --git a/database/std_connect.go b/database/std_connect.go new file mode 100644 index 0000000..b816e73 --- /dev/null +++ b/database/std_connect.go @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package iotdb_go + +import "context" + +type stdConnect interface { + isBad() bool + close() error + query(ctx context.Context, release nativeTransportRelease, query string, args ...any) (*rows, error) + exec(ctx context.Context, query string, args ...any) error + ping(ctx context.Context) (err error) + commit() error +} diff --git a/go.mod b/go.mod index f3f6709..e8c5ef6 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25 require ( github.com/apache/thrift v0.23.0 + github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.2 ) diff --git a/go.sum b/go.sum index ff6d3f8..b1dda41 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,9 @@ github.com/apache/thrift v0.23.0/go.mod h1:zPt6WxgvTOM6hF92y8C+MkEM5LMxZuk4JcQOi github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=