Skip to content

Commit

Permalink
Fix query extended query execution if there is no Bind step
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Jan 7, 2025
1 parent 351633e commit 153af8a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
30 changes: 18 additions & 12 deletions src/query_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ type QueryHandler struct {
////////////////////////////////////////////////////////////////////////////////////////////////////

type PreparedStatement struct {
Name string
Query string
Statement *sql.Stmt
Variables []interface{}
Portal string
Rows *sql.Rows
Name string
Query string
Statement *sql.Stmt
ParameterOIDs []uint32
Variables []interface{}
Portal string
Rows *sql.Rows
}

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -222,9 +223,10 @@ func (queryHandler *QueryHandler) HandleParseQuery(message *pgproto3.Parse) ([]p
}

preparedStatement := &PreparedStatement{
Name: message.Name,
Query: query,
Statement: statement,
Name: message.Name,
Query: query,
Statement: statement,
ParameterOIDs: message.ParameterOIDs,
}

messages := []pgproto3.Message{&pgproto3.ParseComplete{}}
Expand Down Expand Up @@ -272,8 +274,8 @@ func (queryHandler *QueryHandler) HandleBindQuery(message *pgproto3.Bind, prepar
func (queryHandler *QueryHandler) HandleDescribeQuery(message *pgproto3.Describe, preparedStatement *PreparedStatement) ([]pgproto3.Message, *PreparedStatement, error) {
switch message.ObjectType {
case 'S': // Statement
if message.Name != preparedStatement.Query {
LogError(queryHandler.config, "Statement mismatch:", message.Name, "instead of", preparedStatement.Query)
if message.Name != preparedStatement.Name {
LogError(queryHandler.config, "Statement mismatch:", message.Name, "instead of", preparedStatement.Name)
return nil, nil, errors.New("statement mismatch")
}
case 'P': // Portal
Expand All @@ -283,6 +285,10 @@ func (queryHandler *QueryHandler) HandleDescribeQuery(message *pgproto3.Describe
}
}

if len(preparedStatement.ParameterOIDs) != len(preparedStatement.Variables) { // Bind step didn't happen before
return []pgproto3.Message{&pgproto3.NoData{}}, preparedStatement, nil
}

rows, err := preparedStatement.Statement.QueryContext(context.Background(), preparedStatement.Variables...)
if err != nil {
LogError(queryHandler.config, "Couldn't execute prepared statement via DuckDB:", preparedStatement.Query+"\n"+err.Error())
Expand All @@ -303,7 +309,7 @@ func (queryHandler *QueryHandler) HandleExecuteQuery(message *pgproto3.Execute,
return nil, errors.New("portal mismatch")
}

if preparedStatement.Rows == nil {
if preparedStatement.Rows == nil { // If Describe step didn't have Bind step before
rows, err := preparedStatement.Statement.QueryContext(context.Background(), preparedStatement.Variables...)
if err != nil {
LogError(queryHandler.config, "Couldn't execute prepared statement via DuckDB:", preparedStatement.Query+"\n"+err.Error())
Expand Down
27 changes: 21 additions & 6 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ func TestHandleQuery(t *testing.T) {
}

func TestHandleParseQuery(t *testing.T) {
t.Run("Handles PARSE extended query", func(t *testing.T) {
t.Run("Handles PARSE extended query step", func(t *testing.T) {
query := "SELECT usename, passwd FROM pg_shadow WHERE usename=$1"
queryHandler := initQueryHandler()
message := &pgproto3.Parse{Query: query}
Expand All @@ -916,7 +916,7 @@ func TestHandleParseQuery(t *testing.T) {
}

func TestHandleBindQuery(t *testing.T) {
t.Run("Handles BIND extended query with text format parameter", func(t *testing.T) {
t.Run("Handles BIND extended query step with text format parameter", func(t *testing.T) {
queryHandler := initQueryHandler()
query := "SELECT usename, passwd FROM pg_shadow WHERE usename=$1"
parseMessage := &pgproto3.Parse{Query: query}
Expand All @@ -941,7 +941,7 @@ func TestHandleBindQuery(t *testing.T) {
}
})

t.Run("Handles BIND extended query with binary format parameter", func(t *testing.T) {
t.Run("Handles BIND extended query step with binary format parameter", func(t *testing.T) {
queryHandler := initQueryHandler()
query := "SELECT c.oid FROM pg_catalog.pg_class c WHERE c.relnamespace = $1"
parseMessage := &pgproto3.Parse{Query: query}
Expand Down Expand Up @@ -972,10 +972,10 @@ func TestHandleBindQuery(t *testing.T) {
}

func TestHandleDescribeQuery(t *testing.T) {
t.Run("Handles DESCRIBE extended query", func(t *testing.T) {
t.Run("Handles DESCRIBE extended query step", func(t *testing.T) {
queryHandler := initQueryHandler()
query := "SELECT usename, passwd FROM pg_shadow WHERE usename=$1"
parseMessage := &pgproto3.Parse{Query: query}
parseMessage := &pgproto3.Parse{Query: query, ParameterOIDs: []uint32{pgtype.TextOID}}
_, preparedStatement, _ := queryHandler.HandleParseQuery(parseMessage)
bindMessage := &pgproto3.Bind{Parameters: [][]byte{[]byte("bemidb")}}
_, preparedStatement, _ = queryHandler.HandleBindQuery(bindMessage, preparedStatement)
Expand All @@ -992,10 +992,25 @@ func TestHandleDescribeQuery(t *testing.T) {
t.Errorf("Expected the prepared statement to have rows")
}
})

t.Run("Handles DESCRIBE (Statement) extended query step if there was no BIND step", func(t *testing.T) {
queryHandler := initQueryHandler()
query := "SELECT usename, passwd FROM pg_shadow WHERE usename=$1"
parseMessage := &pgproto3.Parse{Query: query, ParameterOIDs: []uint32{pgtype.TextOID}}
_, preparedStatement, _ := queryHandler.HandleParseQuery(parseMessage)
message := &pgproto3.Describe{ObjectType: 'S'}

messages, preparedStatement, err := queryHandler.HandleDescribeQuery(message, preparedStatement)

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.NoData{},
})
})
}

func TestHandleExecuteQuery(t *testing.T) {
t.Run("Handles EXECUTE extended query", func(t *testing.T) {
t.Run("Handles EXECUTE extended query step", func(t *testing.T) {
queryHandler := initQueryHandler()
query := "SELECT usename, passwd FROM pg_shadow WHERE usename=$1"
parseMessage := &pgproto3.Parse{Query: query}
Expand Down

0 comments on commit 153af8a

Please sign in to comment.