Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GUID conversion #207

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Other supported formats are listed below.
* `multisubnetfailover`
* `true` (Default) Client attempt to connect to all IPs simultaneously.
* `false` Client attempts to connect to IPs in serial.
* `guid conversion` - Enables the conversion of GUIDs, so that byte order is preserved. UniqueIdentifier isn't supported for nullable fields, NullUniqueIdentifier must be used instead.

### Connection parameters for namedpipe package
* `pipe` - If set, no Browser query is made and named pipe used will be `\\<host>\pipe\<pipe>`
Expand Down
2 changes: 1 addition & 1 deletion alwaysencrypted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
func testProviderErrorHandling(t *testing.T, name string, provider aecmk.ColumnEncryptionKeyProvider, sel string, insert string, insertArgs []interface{}) {
t.Helper()
testProvider := &testKeyProvider{fallback: provider}
connector, _ := getTestConnector(t)
connector, _ := getTestConnector(t, false /*guidConversion*/)
connector.RegisterCekProvider(name, testProvider)
conn := sql.OpenDB(connector)
defer conn.Close()
Expand Down
2 changes: 1 addition & 1 deletion bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (b *Bulk) createColMetadata() []byte {
}
binary.Write(buf, binary.LittleEndian, uint16(col.Flags))

writeTypeInfo(buf, &b.bulkColumns[i].ti, false)
writeTypeInfo(buf, &b.bulkColumns[i].ti, false, b.cn.sess.encoding)

if col.ti.TypeId == typeNText ||
col.ti.TypeId == typeText ||
Expand Down
12 changes: 10 additions & 2 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) {
}
}

func TestBulkcopy(t *testing.T) {
func testBulkcopy(t *testing.T, guidConversion bool) {
// TDS level Bulk Insert is not supported on Azure SQL Server.
if dsn := makeConnStr(t); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
if dsn := makeConnStrSettingGuidConversion(t, guidConversion); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
t.Skip("TDS level bulk copy is not supported on Azure SQL Server")
}
type testValue struct {
Expand Down Expand Up @@ -300,6 +300,14 @@ func TestBulkcopy(t *testing.T) {
}
}

func TestBulkcopyWithGuidConversion(t *testing.T) {
testBulkcopy(t, true /*guidConversion*/)
}

func TestBulkcopy(t *testing.T) {
testBulkcopy(t, false /*guidConversion*/)
}

func compareValue(a interface{}, expected interface{}) bool {
if got, ok := a.([]uint8); ok {
if _, ok := expected.([]uint8); !ok {
Expand Down
27 changes: 27 additions & 0 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,14 @@ const (
Pipe = "pipe"
MultiSubnetFailover = "multisubnetfailover"
NoTraceID = "notraceid"
GuidConversion = "guid conversion"
)

type EncodeParameters struct {
// Properly convert GUIDs, using correct byte endianness
GuidConversion bool
}

type Config struct {
Port uint64
Host string
Expand Down Expand Up @@ -141,6 +147,8 @@ type Config struct {
// When true, no connection id or trace id value is sent in the prelogin packet.
// Some cloud servers may block connections that lack such values.
NoTraceID bool
// Parameters related to type encoding
Encoding EncodeParameters
}

func readDERFile(filename string) ([]byte, error) {
Expand Down Expand Up @@ -525,6 +533,20 @@ func Parse(dsn string) (Config, error) {
p.NoTraceID = notraceid
}
}

guidConversion, ok := params[GuidConversion]
if ok {
var err error
p.Encoding.GuidConversion, err = strconv.ParseBool(guidConversion)
if err != nil {
f := "invalid guid conversion '%s': %s"
return p, fmt.Errorf(f, guidConversion, err.Error())
}
} else {
// set to false for backward compatibility
p.Encoding.GuidConversion = false
}

return p, nil
}

Expand Down Expand Up @@ -585,6 +607,11 @@ func (p Config) URL() *url.URL {
if p.ColumnEncryption {
q.Add("columnencryption", "true")
}

if p.Encoding.GuidConversion {
q.Add(GuidConversion, strconv.FormatBool(p.Encoding.GuidConversion))
}

if len(q) > 0 {
res.RawQuery = q.Encode()
}
Expand Down
5 changes: 4 additions & 1 deletion msdsn/conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ func TestValidConnectionString(t *testing.T) {
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry && !p.ColumnEncryption
}},
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool {
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && !p.Encoding.GuidConversion
}},
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1&guid+conversion=true", func(p Config) bool {
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && p.Encoding.GuidConversion
}},
}
for _, ts := range connStrings {
Expand Down
2 changes: 1 addition & 1 deletion mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
params[0] = makeStrParam(s.query)
params[1] = makeStrParam(strings.Join(decls, ","))
}
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset, conn.sess.encoding); err != nil {
conn.sess.LogF(ctx, msdsn.LogErrors, "Failed to send Rpc with %v", err)
conn.connectionGood = false
return fmt.Errorf("failed to send RPC: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion mssql_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
err = errCalTypes
return
}
res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes)
res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes, s.c.sess.encoding)
if err != nil {
return
}
Expand Down
22 changes: 16 additions & 6 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func driverWithProcess(t *testing.T, tl Logger) *Driver {
}
}

func TestSelect(t *testing.T) {
conn, logger := open(t)
func testSelect(t *testing.T, guidConversion bool) {
conn, logger := openSettingGuidConversion(t, guidConversion)
defer conn.Close()
defer logger.StopLogging()

Expand All @@ -39,6 +39,10 @@ func TestSelect(t *testing.T) {
}

longstr := strings.Repeat("x", 10000)
expectedGuid := []byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}
if guidConversion {
expectedGuid = []byte{0xFF, 0x19, 0x96, 0x6F, 0x86, 0x8B, 0x11, 0xD0, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}
}

values := []testStruct{
{"1", int64(1)},
Expand Down Expand Up @@ -83,8 +87,7 @@ func TestSelect(t *testing.T) {
{"cast('2079-06-06T23:59:00' as smalldatetime)",
time.Date(2079, 6, 6, 23, 59, 0, 0, time.UTC)},
{"cast(NULL as smalldatetime)", nil},
{"cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier)",
[]byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}},
{"cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier)", expectedGuid},
{"cast(NULL as uniqueidentifier)", nil},
{"cast(0x1234 as varbinary(2))", []byte{0x12, 0x34}},
{"cast(N'abc' as nvarchar(max))", "abc"},
Expand Down Expand Up @@ -114,8 +117,7 @@ func TestSelect(t *testing.T) {
{"cast(cast(N'chào' as nvarchar(max)) collate Vietnamese_CI_AI as varchar(max))", "chào"}, // cp1258
{fmt.Sprintf("cast(N'%s' as nvarchar(max))", longstr), longstr},
{"cast(NULL as sql_variant)", nil},
{"cast(cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier) as sql_variant)",
[]byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}},
{"cast(cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier) as sql_variant)", expectedGuid},
{"cast(cast(1 as bit) as sql_variant)", true},
{"cast(cast(10 as tinyint) as sql_variant)", int64(10)},
{"cast(cast(-10 as smallint) as sql_variant)", int64(-10)},
Expand Down Expand Up @@ -214,6 +216,14 @@ func TestSelect(t *testing.T) {
})
}

func TestSelectWithGuidConversion(t *testing.T) {
testSelect(t, true /*guidConversion*/)
}

func TestSelect(t *testing.T) {
testSelect(t, false /*guidConversion*/)
}

func TestSelectDateTimeOffset(t *testing.T) {
type testStruct struct {
sql string
Expand Down
8 changes: 5 additions & 3 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package mssql

import (
"encoding/binary"

"github.com/microsoft/go-mssqldb/msdsn"
)

type procId struct {
Expand Down Expand Up @@ -43,7 +45,7 @@ var (
)

// http://msdn.microsoft.com/en-us/library/dd357576.aspx
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) (err error) {
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool, encoding msdsn.EncodeParameters) (err error) {
buf.BeginPacket(packRPCRequest, resetSession)
writeAllHeaders(buf, headers)
if len(proc.name) == 0 {
Expand Down Expand Up @@ -73,7 +75,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
if err = binary.Write(buf, binary.LittleEndian, param.Flags); err != nil {
return
}
err = writeTypeInfo(buf, &param.ti, (param.Flags&fByRevValue) != 0)
err = writeTypeInfo(buf, &param.ti, (param.Flags&fByRevValue) != 0, encoding)
if err != nil {
return
}
Expand All @@ -82,7 +84,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
return
}
if (param.Flags & fEncrypted) == fEncrypted {
err = writeTypeInfo(buf, &param.tiOriginal, false)
err = writeTypeInfo(buf, &param.tiOriginal, false, encoding)
if err != nil {
return
}
Expand Down
1 change: 1 addition & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func newSession(outbuf *tdsBuffer, logger ContextLogger, p msdsn.Config) *tdsSes
logger: logger,
logFlags: uint64(p.LogFlags),
aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()},
encoding: p.Encoding,
}
_ = sess.activityid.Scan(p.ActivityID)
// generating a guid has a small chance of failure. Make a best effort
Expand Down
1 change: 1 addition & 0 deletions tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ type tdsSession struct {
aeSettings *alwaysEncryptedSettings
connid UniqueIdentifier
activityid UniqueIdentifier
encoding msdsn.EncodeParameters
}

type alwaysEncryptedSettings struct {
Expand Down
14 changes: 10 additions & 4 deletions tds_go110_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,22 @@ import (
"testing"
)

func open(t testing.TB) (*sql.DB, *testLogger) {
connector, logger := getTestConnector(t)
func openSettingGuidConversion(t testing.TB, guidConversion bool) (*sql.DB, *testLogger) {
connector, logger := getTestConnector(t, guidConversion)
conn := sql.OpenDB(connector)
return conn, logger
}

func getTestConnector(t testing.TB) (*Connector, *testLogger) {
func open(t testing.TB) (*sql.DB, *testLogger) {
return openSettingGuidConversion(t, false /*guidConversion*/)
}

func getTestConnector(t testing.TB, guidConversion bool) (*Connector, *testLogger) {
tl := testLogger{t: t}
SetLogger(&tl)
connector, err := NewConnector(makeConnStr(t).String())

connectionString := makeConnStrSettingGuidConversion(t, guidConversion).String()
connector, err := NewConnector(connectionString)
if err != nil {
t.Error("Open connection failed:", err.Error())
return nil, &tl
Expand Down
9 changes: 7 additions & 2 deletions tds_go110pre_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build !go1.10
// +build !go1.10

package mssql
Expand All @@ -7,14 +8,18 @@ import (
"testing"
)

func open(t *testing.T) (*sql.DB, *testLogger) {
func openSettingGuidConversion(t *testing.T, guidConversion bool) (*sql.DB, *testLogger) {
tl := testLogger{t: t}
SetLogger(&tl)
checkConnStr(t)
conn, err := sql.Open("sqlserver", makeConnStr(t).String())
conn, err := sql.Open("sqlserver", makeConnStrSettingGuidConversion(t, guidConversion).String())
if err != nil {
t.Error("Open connection failed:", err.Error())
return nil, &tl
}
return conn, &tl
}

func open(t *testing.T) (*sql.DB, *testLogger) {
return openSettingGuidConversion(t, false /*guidConversion*/)
}
6 changes: 6 additions & 0 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,12 @@ func makeConnStr(t testing.TB) *url.URL {
return testConnParams(t).URL()
}

func makeConnStrSettingGuidConversion(t testing.TB, guidConversion bool) *url.URL {
config := testConnParams(t)
config.Encoding.GuidConversion = guidConversion
return config.URL()
}

type testLogger struct {
t testing.TB
mu sync.Mutex
Expand Down
12 changes: 6 additions & 6 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) {
for i := range columns {
column := &columns[i]
baseTi := getBaseTypeInfo(r, true)
typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta)
typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta, s.encoding)
typeInfo.UserType = baseTi.UserType
typeInfo.Flags = baseTi.Flags
typeInfo.TypeId = baseTi.TypeId
Expand All @@ -621,7 +621,7 @@ func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) {

if column.isEncrypted() && s.alwaysEncrypted {
// Read Crypto Metadata
cryptoMeta := parseCryptoMetadata(r, cekTable)
cryptoMeta := parseCryptoMetadata(r, cekTable, s.encoding)
cryptoMeta.typeInfo.Flags = baseTi.Flags
column.cryptoMeta = &cryptoMeta
} else {
Expand Down Expand Up @@ -657,14 +657,14 @@ type cryptoMetadata struct {
typeInfo typeInfo
}

func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata {
func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable, encoding msdsn.EncodeParameters) cryptoMetadata {
ordinal := uint16(0)
if cekTable != nil {
ordinal = r.uint16()
}

typeInfo := getBaseTypeInfo(r, false)
ti := readTypeInfo(r, typeInfo.TypeId, nil)
ti := readTypeInfo(r, typeInfo.TypeId, nil, encoding)
ti.UserType = typeInfo.UserType
ti.Flags = typeInfo.Flags
ti.TypeId = typeInfo.TypeId
Expand Down Expand Up @@ -929,11 +929,11 @@ func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) {

var cryptoMetadata *cryptoMetadata = nil
if s.alwaysEncrypted && (ti.Flags&fEncrypted) == fEncrypted {
cm := parseCryptoMetadata(r, nil) // CryptoMetadata
cm := parseCryptoMetadata(r, nil, s.encoding) // CryptoMetadata
cryptoMetadata = &cm
}

ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata)
ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata, s.encoding)
nv.Value = ti2.Reader(&ti2, r, cryptoMetadata)

return
Expand Down
6 changes: 4 additions & 2 deletions tvp_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"reflect"
"strings"
"time"

"github.com/microsoft/go-mssqldb/msdsn"
)

const (
Expand Down Expand Up @@ -62,7 +64,7 @@ func (tvp TVP) check() error {
return nil
}

func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int) ([]byte, error) {
func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int, encoding msdsn.EncodeParameters) ([]byte, error) {
if len(columnStr) != len(tvpFieldIndexes) {
return nil, ErrorWrongTyping
}
Expand All @@ -80,7 +82,7 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
for i, column := range columnStr {
binary.Write(buf, binary.LittleEndian, column.UserType)
binary.Write(buf, binary.LittleEndian, column.Flags)
writeTypeInfo(buf, &columnStr[i].ti, false)
writeTypeInfo(buf, &columnStr[i].ti, false, encoding)
writeBVarChar(buf, "")
}
// The returned error is always nil
Expand Down
Loading