diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/IntegerParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/IntegerParser.java index e8676706b..7f2095a51 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/IntegerParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/IntegerParser.java @@ -17,6 +17,8 @@ import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import java.math.BigDecimal; +import java.math.RoundingMode; import org.postgresql.util.ByteConverter; /** Translate from wire protocol to int. */ @@ -36,7 +38,8 @@ class IntegerParser extends Parser { case TEXT: String stringValue = new String(item); try { - this.item = Integer.valueOf(stringValue); + this.item = + new BigDecimal(stringValue).setScale(0, RoundingMode.HALF_UP).intValueExact(); } catch (Exception exception) { throw PGExceptionFactory.newPGException("Invalid int4 value: " + stringValue); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/LongParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/LongParser.java index 2cc039425..45bcfe028 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/LongParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/LongParser.java @@ -21,6 +21,8 @@ import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.pgadapter.ProxyServer.DataFormat; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import javax.annotation.Nonnull; import org.postgresql.util.ByteConverter; @@ -43,7 +45,8 @@ public class LongParser extends Parser { case TEXT: String stringValue = new String(item); try { - this.item = Long.valueOf(stringValue); + this.item = + new BigDecimal(stringValue).setScale(0, RoundingMode.HALF_UP).longValueExact(); } catch (Exception exception) { throw PGExceptionFactory.newPGException("Invalid int8 value: " + stringValue); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/ShortParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/ShortParser.java index 6147c2211..cf61db3ad 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/ShortParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/ShortParser.java @@ -16,6 +16,8 @@ import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import java.math.BigDecimal; +import java.math.RoundingMode; import org.postgresql.util.ByteConverter; /** Translate from wire protocol to short. */ @@ -31,9 +33,10 @@ class ShortParser extends Parser { case TEXT: String stringValue = new String(item); try { - this.item = Short.valueOf(stringValue); + this.item = + new BigDecimal(stringValue).setScale(0, RoundingMode.HALF_UP).shortValueExact(); } catch (Exception exception) { - throw PGExceptionFactory.newPGException("Invalid int4 value: " + stringValue); + throw PGExceptionFactory.newPGException("Invalid int2 value: " + stringValue); } break; case BINARY: diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParser.java index b013abcfb..29183be8e 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParser.java @@ -27,8 +27,11 @@ import com.google.common.base.Preconditions; import java.nio.charset.StandardCharsets; import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; import java.time.OffsetDateTime; import java.time.ZoneId; +import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; import java.time.temporal.ChronoField; @@ -71,7 +74,7 @@ public class TimestampParser extends Parser { .appendOffset(OptionsMetadata.isJava8() ? "+HH:mm" : "+HH:mm:ss", "+00") .toFormatter(); - private static final DateTimeFormatter TIMESTAMP_INPUT_FORMATTER = + private static final DateTimeFormatter TIMESTAMPTZ_INPUT_FORMATTER = new DateTimeFormatterBuilder() .parseLenient() .parseCaseInsensitive() @@ -80,6 +83,13 @@ public class TimestampParser extends Parser { // Java 8 does not support seconds in timezone offset. .appendOffset(OptionsMetadata.isJava8() ? "+HH:mm" : "+HH:mm:ss", "+00:00:00") .toFormatter(); + private static final DateTimeFormatter TIMESTAMP_INPUT_FORMATTER = + new DateTimeFormatterBuilder() + .parseLenient() + .parseCaseInsensitive() + .appendPattern("yyyy-MM-dd[[ ]['T']HH:mm[:ss][XXX]]") + .appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true) + .toFormatter(); private final SessionState sessionState; @@ -98,7 +108,7 @@ public class TimestampParser extends Parser { if (item != null) { switch (formatCode) { case TEXT: - this.item = toTimestamp(new String(item, StandardCharsets.UTF_8)); + this.item = toTimestamp(new String(item, StandardCharsets.UTF_8), sessionState); break; case BINARY: this.item = toTimestamp(item); @@ -123,15 +133,37 @@ public static Timestamp toTimestamp(@Nonnull byte[] data) { } /** Converts the given string value to a {@link Timestamp}. */ - public static Timestamp toTimestamp(String value) { + public static Timestamp toTimestamp(String value, SessionState sessionState) { try { String stringValue = toPGString(value); - TemporalAccessor temporalAccessor = TIMESTAMP_INPUT_FORMATTER.parse(stringValue); + TemporalAccessor temporalAccessor = TIMESTAMPTZ_INPUT_FORMATTER.parse(stringValue); return Timestamp.ofTimeSecondsAndNanos( temporalAccessor.getLong(ChronoField.INSTANT_SECONDS), temporalAccessor.get(ChronoField.NANO_OF_SECOND)); - } catch (Exception exception) { - throw PGExceptionFactory.newPGException("Invalid timestamp value: " + value); + } catch (Exception ignore) { + try { + TemporalAccessor temporalAccessor = + TIMESTAMP_INPUT_FORMATTER.parseBest( + value, ZonedDateTime::from, LocalDateTime::from, LocalDate::from); + ZonedDateTime zonedDateTime = null; + if (temporalAccessor instanceof ZonedDateTime) { + zonedDateTime = (ZonedDateTime) temporalAccessor; + } else if (temporalAccessor instanceof LocalDateTime) { + LocalDateTime localDateTime = (LocalDateTime) temporalAccessor; + zonedDateTime = localDateTime.atZone(sessionState.getTimezone()); + } else if (temporalAccessor instanceof LocalDate) { + LocalDate localDate = (LocalDate) temporalAccessor; + zonedDateTime = localDate.atStartOfDay().atZone(sessionState.getTimezone()); + } + if (zonedDateTime != null) { + return Timestamp.ofTimeSecondsAndNanos( + zonedDateTime.getLong(ChronoField.INSTANT_SECONDS), + zonedDateTime.get(ChronoField.NANO_OF_SECOND)); + } + throw PGExceptionFactory.newPGException("Invalid timestamp value: " + value); + } catch (Exception exception) { + throw PGExceptionFactory.newPGException("Invalid timestamp value: " + value); + } } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/session/CopySettings.java b/src/main/java/com/google/cloud/spanner/pgadapter/session/CopySettings.java index 3c9de5ce9..3be3397c7 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/session/CopySettings.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/session/CopySettings.java @@ -78,6 +78,11 @@ public CopySettings(SessionState sessionState) { this.sessionState = sessionState; } + /** Returns the underlying session state for these copy settings. */ + public SessionState getSessionState() { + return sessionState; + } + /** Returns the maximum number of parallel transactions for a single COPY operation. */ public int getMaxParallelism() { return sessionState.getIntegerSetting("spanner", "copy_max_parallelism", 128); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatement.java index f8fa71d90..9df015db6 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatement.java @@ -24,6 +24,8 @@ import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; +import com.google.cloud.spanner.pgadapter.session.SessionState; +import com.google.cloud.spanner.pgadapter.statements.LiteralParser.Literal; import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; import com.google.cloud.spanner.pgadapter.wireprotocol.BindMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.ManuallyCreatedToken; @@ -71,7 +73,13 @@ public ExecuteStatement( NO_PARAMS, ImmutableList.of(), ImmutableList.of()); - this.executeStatement = parse(originalStatement.getSql()); + this.executeStatement = + parse( + originalStatement.getSql(), + connectionHandler + .getExtendedQueryProtocolHandler() + .getBackendConnection() + .getSessionState()); } @Override @@ -126,7 +134,7 @@ public IntermediatePortalStatement createPortal( return this; } - static ParsedExecuteStatement parse(String sql) { + static ParsedExecuteStatement parse(String sql, SessionState sessionState) { Preconditions.checkNotNull(sql); SimpleParser parser = new SimpleParser(sql); @@ -139,7 +147,7 @@ static ParsedExecuteStatement parse(String sql) { } String statementName = unquoteOrFoldIdentifier(name.name); - List parameters; + List parameters; if (parser.eatToken("(")) { List parametersList = parser.parseExpressionList(); if (parametersList == null || parametersList.isEmpty()) { @@ -149,7 +157,9 @@ static ParsedExecuteStatement parse(String sql) { throw PGExceptionFactory.newPGException("missing closing parentheses in parameters list"); } parameters = - parametersList.stream().map(ExecuteStatement::unquoteString).collect(Collectors.toList()); + parametersList.stream() + .map(LiteralParser::parseSingleConstantLiteralExpression) + .collect(Collectors.toList()); } else { parameters = Collections.emptyList(); } @@ -161,7 +171,11 @@ static ParsedExecuteStatement parse(String sql) { return new ParsedExecuteStatement( statementName, parameters.stream() - .map(p -> p == null ? null : p.getBytes(StandardCharsets.UTF_8)) + .map( + p -> + p == null + ? null + : p.getConvertedValue(sessionState).getBytes(StandardCharsets.UTF_8)) .toArray(byte[][]::new)); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/LiteralParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/LiteralParser.java new file mode 100644 index 000000000..763be117e --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/LiteralParser.java @@ -0,0 +1,429 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.statements; + +import static com.google.cloud.spanner.pgadapter.statements.SimpleParser.DOLLAR; +import static com.google.cloud.spanner.pgadapter.statements.SimpleParser.DOUBLE_QUOTE; +import static com.google.cloud.spanner.pgadapter.statements.SimpleParser.SINGLE_QUOTE; +import static com.google.cloud.spanner.pgadapter.statements.SimpleParser.isValidIdentifierChar; +import static com.google.cloud.spanner.pgadapter.statements.SimpleParser.isValidIdentifierFirstChar; + +import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; +import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.parsers.Parser; +import com.google.cloud.spanner.pgadapter.parsers.Parser.FormatCode; +import com.google.cloud.spanner.pgadapter.session.SessionState; +import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TypeDefinition; +import com.google.common.collect.ImmutableMap; +import java.nio.charset.StandardCharsets; +import java.util.Map.Entry; +import java.util.Objects; +import org.apache.commons.text.StringEscapeUtils; +import org.postgresql.core.Oid; + +class LiteralParser { + static final ImmutableMap TYPE_NAME_TO_OID_MAPPING = + ImmutableMap.builder() + .put("unknown", Oid.UNSPECIFIED) + .put("bigint", Oid.INT8) + .put("bigint[]", Oid.INT8_ARRAY) + .put("int8", Oid.INT8) + .put("int8[]", Oid.INT8_ARRAY) + .put("int4", Oid.INT4) + .put("int4[]", Oid.INT4_ARRAY) + .put("int", Oid.INT4) + .put("int[]", Oid.INT4_ARRAY) + .put("integer", Oid.INT4) + .put("integer[]", Oid.INT4_ARRAY) + .put("boolean", Oid.BOOL) + .put("boolean[]", Oid.BOOL_ARRAY) + .put("bool", Oid.BOOL) + .put("bool[]", Oid.BOOL_ARRAY) + .put("bytea", Oid.BYTEA) + .put("bytea[]", Oid.BYTEA_ARRAY) + .put("character varying", Oid.VARCHAR) + .put("character varying[]", Oid.VARCHAR_ARRAY) + .put("varchar", Oid.VARCHAR) + .put("varchar[]", Oid.VARCHAR_ARRAY) + .put("date", Oid.DATE) + .put("date[]", Oid.DATE_ARRAY) + .put("double precision", Oid.FLOAT8) + .put("double precision[]", Oid.FLOAT8_ARRAY) + .put("float8", Oid.FLOAT8) + .put("float8[]", Oid.FLOAT8_ARRAY) + .put("jsonb", Oid.JSONB) + .put("jsonb[]", Oid.JSONB_ARRAY) + .put("numeric", Oid.NUMERIC) + .put("numeric[]", Oid.NUMERIC_ARRAY) + .put("decimal", Oid.NUMERIC) + .put("decimal[]", Oid.NUMERIC_ARRAY) + .put("text", Oid.TEXT) + .put("text[]", Oid.TEXT_ARRAY) + .put("timestamp with time zone", Oid.TIMESTAMPTZ) + .put("timestamp with time zone[]", Oid.TIMESTAMPTZ_ARRAY) + .put("timestamptz", Oid.TIMESTAMPTZ) + .put("timestamptz[]", Oid.TIMESTAMPTZ_ARRAY) + .build(); + + static class Literal { + final String value; + final Integer castToOid; + + static Literal of(String value) { + return new Literal(value, null); + } + + static Literal of(String value, Integer castToOid) { + return new Literal(value, castToOid); + } + + private Literal(String value, Integer castToOid) { + this.value = value; + this.castToOid = castToOid; + } + + String getConvertedValue(SessionState sessionState) { + if (castToOid == null) { + return value; + } + Parser pgParser = + Parser.create( + sessionState, value.getBytes(StandardCharsets.UTF_8), castToOid, FormatCode.TEXT); + return pgParser.stringParse(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Literal)) { + return false; + } + Literal other = (Literal) o; + return Objects.equals(value, other.value) && Objects.equals(castToOid, other.castToOid); + } + + @Override + public int hashCode() { + return Objects.hash(value, castToOid); + } + + @Override + public String toString() { + return value + (castToOid == null ? "" : "::" + oidToDataTypeName(castToOid)); + } + } + + static class QuotedString { + final boolean escaped; + final char quote; + final String rawValue; + private String value; + + QuotedString(boolean escaped, char quote, String rawValue) { + this.escaped = escaped; + this.quote = quote; + this.rawValue = rawValue; + } + + String getValue() { + if (this.value == null) { + this.value = + this.escaped + ? unescapeQuotedStringValue(this.rawValue, this.quote) + : quotedStringValue(this.rawValue, this.quote); + } + return this.value; + } + + static String quotedStringValue(String quotedString, char quoteChar) { + if (quotedString.length() < 2 + || quotedString.charAt(0) != quoteChar + || quotedString.charAt(quotedString.length() - 1) != quoteChar) { + throw PGExceptionFactory.newPGException( + quotedString + " is not a valid string", SQLState.SyntaxError); + } + String doubleQuotes = String.valueOf(quoteChar) + quoteChar; + String singleQuote = String.valueOf(quoteChar); + return quotedString + .substring(1, quotedString.length() - 1) + .replace(doubleQuotes, singleQuote); + } + + static String unescapeQuotedStringValue(String quotedString, char quoteChar) { + if (quotedString.length() < 2 + || quotedString.charAt(0) != quoteChar + || quotedString.charAt(quotedString.length() - 1) != quoteChar) { + throw PGExceptionFactory.newPGException( + quotedString + " is not a valid string", SQLState.SyntaxError); + } + if (quotedString.startsWith(quoteChar + "\\x")) { + throw PGExceptionFactory.newPGException( + "PGAdapter does not support hexadecimal byte values in string literals", + SQLState.SyntaxError); + } + String result = + StringEscapeUtils.unescapeJava(quotedString.substring(1, quotedString.length() - 1)); + String doubleQuotes = String.valueOf(quoteChar) + quoteChar; + String singleQuote = String.valueOf(quoteChar); + return result.replace(doubleQuotes, singleQuote); + } + } + + static class DollarQuotedString { + final String tag; + final String value; + + DollarQuotedString(String tag, String value) { + this.tag = tag; + this.value = value; + } + } + + static QuotedString readSingleQuotedString(SimpleParser parser) { + LiteralParser literalParser = new LiteralParser(parser); + return literalParser.readQuotedString(SINGLE_QUOTE); + } + + static QuotedString readDoubleQuotedString(SimpleParser parser) { + LiteralParser literalParser = new LiteralParser(parser); + return literalParser.readQuotedString(DOUBLE_QUOTE); + } + + static int dataTypeNameToOid(String type) { + SimpleParser parser = new SimpleParser(type); + TypeDefinition typeDefinition = parser.readType(); + Integer oid = + TYPE_NAME_TO_OID_MAPPING.get(typeDefinition.getNameAndArrayBrackets().toLowerCase()); + if (oid != null) { + return oid; + } + throw PGExceptionFactory.newPGException("unknown type name: " + type); + } + + static String oidToDataTypeName(int oid) { + for (Entry entry : TYPE_NAME_TO_OID_MAPPING.entrySet()) { + if (entry.getValue() == oid) { + return entry.getKey(); + } + } + throw PGExceptionFactory.newPGException("unknown oid: " + oid); + } + + static Literal parseSingleConstantLiteralExpression(String expression) { + LiteralParser literalParser = new LiteralParser(new SimpleParser(expression)); + Literal result = literalParser.readConstantLiteralExpression(); + if (literalParser.parser.hasMoreTokens()) { + throw PGExceptionFactory.newPGException("Unexpected tokens in expression: " + expression); + } + return result; + } + + private final SimpleParser parser; + + LiteralParser(SimpleParser parser) { + this.parser = parser; + } + + /** + * Reads a constant literal with a possible type cast. Does not support recursive casts or any + * other constant expressions. That is; the following is supported: + * + *
    + *
  • 'test' + *
  • e'test' + *
  • 100 + *
  • cast('test' as varchar) + *
  • varchar 'test' + *
  • varchar('test') + *
  • 'test'::varchar + *
+ * + *

The following is not supported: + * + *

    + *
  • `100+100` + *
  • `cast('100'::int as varchar)` + *
  • `(varchar '100')::int` + *
+ */ + Literal readConstantLiteralExpression() { + if (parser.eatKeyword("null")) { + return null; + } + boolean cast = false; + TypeDefinition precedingTypeDefinition = null; + boolean functionStyleTypeCast = false; + if (parser.eatKeyword("cast")) { + if (!parser.eatToken("(")) { + throw PGExceptionFactory.newPGException( + "Missing opening parentheses for CAST: " + parser.getSql(), SQLState.SyntaxError); + } + cast = true; + } else { + for (String typeName : TYPE_NAME_TO_OID_MAPPING.keySet()) { + if (parser.peekKeyword(typeName)) { + precedingTypeDefinition = parser.readType(); + functionStyleTypeCast = parser.eatToken("("); + break; + } + } + } + String value = readLiteralValue(precedingTypeDefinition != null); + TypeDefinition typeDefinition = null; + if (precedingTypeDefinition != null) { + if (functionStyleTypeCast) { + if (!parser.eatToken(")")) { + throw PGExceptionFactory.newPGException( + String.format( + "Missing closing parentheses for %s: %s", + precedingTypeDefinition.name, parser.getSql()), + SQLState.SyntaxError); + } + } + typeDefinition = precedingTypeDefinition; + } else if (cast) { + typeDefinition = eatAsType(); + if (!parser.eatToken(")")) { + throw PGExceptionFactory.newPGException( + String.format("Missing closing parentheses for CAST: %s", parser.getSql()), + SQLState.SyntaxError); + } + } else { + // Check for the '::' cast operator. + if (parser.eatToken("::")) { + typeDefinition = parser.readType(); + } + } + if (typeDefinition != null) { + return new Literal(value, dataTypeNameToOid(typeDefinition.name)); + } + return new Literal(value, null); + } + + String readLiteralValue(boolean mustBeQuoted) { + parser.skipWhitespaces(); + if (parser.getPos() >= parser.getSql().length()) { + throw PGExceptionFactory.newPGException("Invalid literal: " + parser.getSql()); + } + if (parser.peekCharsIgnoreCase("'") + || parser.peekCharsIgnoreCase("e'") + || parser.peekCharsIgnoreCase("b'") + || parser.peekCharsIgnoreCase("x'") + || parser.peekCharsIgnoreCase("u&'")) { + QuotedString quotedString = readQuotedString('\''); + return quotedString.getValue(); + } else if (parser.getSql().charAt(parser.getPos()) == DOLLAR + && parser.getSql().length() > (parser.getPos() + 1) + && (parser.getSql().charAt(parser.getPos() + 1) == DOLLAR + || isValidIdentifierFirstChar(parser.getSql().charAt(parser.getPos() + 1))) + && parser.getSql().indexOf(DOLLAR, parser.getPos() + 1) > -1) { + DollarQuotedString dollarQuotedString = readDollarQuotedString(); + return dollarQuotedString.value; + } else if (mustBeQuoted) { + throw PGExceptionFactory.newPGException( + "Expression must be a quoted string", SQLState.SyntaxError); + } else { + return readNumericLiteralValue(); + } + } + + String readNumericLiteralValue() { + int startPos = parser.getPos(); + + // Accept a leading sign. + if (currentChar() == '+') { + incPos(); + } else if (currentChar() == '-') { + incPos(); + } + // Note that this loop will continue as long as the literal contains valid characters for a + // numeric literal, or other characters that would otherwise not indicate the end of the token. + // That means that this method may return something like "100abc". This will then fail at a + // later moment when it is being converted to an actual number. + while (isValidPos() + && (currentChar() == '.' + || Character.isDigit(currentChar()) + || currentChar() == '+' + || currentChar() == '-' + || isValidIdentifierChar(currentChar()))) { + if (currentChar() == '+' || currentChar() == '-') { + if (!(prevChar() == 'e' || prevChar() == 'E')) { + break; + } + } + incPos(); + } + return parser.getSql().substring(startPos, parser.getPos()); + } + + private char currentChar() { + return parser.getSql().charAt(parser.getPos()); + } + + private char prevChar() { + if (parser.getPos() == 0) { + return 0; + } + return parser.getSql().charAt(parser.getPos() - 1); + } + + private boolean isValidPos() { + return parser.getPos() < parser.getSql().length(); + } + + private void incPos() { + parser.setPos(parser.getPos() + 1); + } + + private TypeDefinition eatAsType() { + if (!parser.eatKeyword("as")) { + throw PGExceptionFactory.newPGException("Missing AS keyword in CAST: " + parser.getSql()); + } + return parser.readType(); + } + + QuotedString readQuotedString(char quote) { + parser.skipWhitespaces(); + if (parser.getPos() >= parser.getSql().length()) { + throw PGExceptionFactory.newPGException("Unexpected end of expression", SQLState.SyntaxError); + } + boolean escaped = parser.eatToken("e"); + if (parser.getSql().charAt(parser.getPos()) != quote) { + throw PGExceptionFactory.newPGException( + "Invalid quote character: " + parser.getSql().charAt(parser.getPos()), + SQLState.SyntaxError); + } + int startPos = parser.getPos(); + if (parser.skipQuotedString(escaped)) { + return new QuotedString(escaped, quote, parser.getSql().substring(startPos, parser.getPos())); + } + throw PGExceptionFactory.newPGException("Missing end quote character", SQLState.SyntaxError); + } + + DollarQuotedString readDollarQuotedString() { + int startPos = parser.getPos(); + if (!parser.eatToken("$")) { + throw PGExceptionFactory.newPGException("Missing expected token: '$'"); + } + String tag = parser.parseDollarQuotedTag(); + parser.setPos(startPos); + if (!parser.skipDollarQuotedString()) { + throw PGExceptionFactory.newPGException("Invalid dollar-quoted string: " + parser.getSql()); + } + String rawValue = parser.getSql().substring(startPos, parser.getPos()); + String value = rawValue.substring(tag.length() + 2, rawValue.length() - tag.length() - 2); + + return new DollarQuotedString(tag, value); + } +} diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatement.java index bd69f4617..96ec59f27 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatement.java @@ -26,65 +26,19 @@ import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; import com.google.cloud.spanner.pgadapter.statements.BackendConnection.NoResult; import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; -import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TypeDefinition; import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.ManuallyCreatedToken; import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.PreparedType; import com.google.cloud.spanner.pgadapter.wireprotocol.DescribeMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.Futures; import java.util.List; import java.util.concurrent.Future; import java.util.stream.Collectors; -import org.postgresql.core.Oid; @InternalApi public class PrepareStatement extends IntermediatePortalStatement { - private static final ImmutableMap TYPE_NAME_TO_OID_MAPPING = - ImmutableMap.builder() - .put("unknown", Oid.UNSPECIFIED) - .put("bigint", Oid.INT8) - .put("bigint[]", Oid.INT8_ARRAY) - .put("int8", Oid.INT8) - .put("int8[]", Oid.INT8_ARRAY) - .put("int4", Oid.INT4) - .put("int4[]", Oid.INT4_ARRAY) - .put("int", Oid.INT4) - .put("int[]", Oid.INT4_ARRAY) - .put("integer", Oid.INT4) - .put("integer[]", Oid.INT4_ARRAY) - .put("boolean", Oid.BOOL) - .put("boolean[]", Oid.BOOL_ARRAY) - .put("bool", Oid.BOOL) - .put("bool[]", Oid.BOOL_ARRAY) - .put("bytea", Oid.BYTEA) - .put("bytea[]", Oid.BYTEA_ARRAY) - .put("character varying", Oid.VARCHAR) - .put("character varying[]", Oid.VARCHAR_ARRAY) - .put("varchar", Oid.VARCHAR) - .put("varchar[]", Oid.VARCHAR_ARRAY) - .put("date", Oid.DATE) - .put("date[]", Oid.DATE_ARRAY) - .put("double precision", Oid.FLOAT8) - .put("double precision[]", Oid.FLOAT8_ARRAY) - .put("float8", Oid.FLOAT8) - .put("float8[]", Oid.FLOAT8_ARRAY) - .put("jsonb", Oid.JSONB) - .put("jsonb[]", Oid.JSONB_ARRAY) - .put("numeric", Oid.NUMERIC) - .put("numeric[]", Oid.NUMERIC_ARRAY) - .put("decimal", Oid.NUMERIC) - .put("decimal[]", Oid.NUMERIC_ARRAY) - .put("text", Oid.TEXT) - .put("text[]", Oid.TEXT_ARRAY) - .put("timestamp with time zone", Oid.TIMESTAMPTZ) - .put("timestamp with time zone[]", Oid.TIMESTAMPTZ_ARRAY) - .put("timestamptz", Oid.TIMESTAMPTZ) - .put("timestamptz[]", Oid.TIMESTAMPTZ_ARRAY) - .build(); - static final class ParsedPreparedStatement { final String name; final int[] dataTypes; @@ -198,7 +152,7 @@ static ParsedPreparedStatement parse(String sql) { } dataTypesBuilder.addAll( dataTypesNames.stream() - .map(PrepareStatement::dataTypeNameToOid) + .map(LiteralParser::dataTypeNameToOid) .collect(Collectors.toList())); } if (!parser.eatKeyword("as")) { @@ -209,15 +163,4 @@ static ParsedPreparedStatement parse(String sql) { dataTypesBuilder.build().stream().mapToInt(i -> i).toArray(), parser.getSql().substring(parser.getPos()).trim()); } - - static int dataTypeNameToOid(String type) { - SimpleParser parser = new SimpleParser(type); - TypeDefinition typeDefinition = parser.readType(); - Integer oid = - TYPE_NAME_TO_OID_MAPPING.get(typeDefinition.getNameAndArrayBrackets().toLowerCase()); - if (oid != null) { - return oid; - } - throw PGExceptionFactory.newPGException("unknown type name: " + type); - } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java index 418ba7ce0..c333724b5 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/SimpleParser.java @@ -19,6 +19,7 @@ import com.google.cloud.spanner.Value; import com.google.cloud.spanner.pgadapter.error.PGExceptionFactory; import com.google.cloud.spanner.pgadapter.error.SQLState; +import com.google.cloud.spanner.pgadapter.statements.LiteralParser.QuotedString; import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; @@ -29,18 +30,17 @@ import java.util.Objects; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import org.apache.commons.text.StringEscapeUtils; /** A very simple parser that can interpret SQL statements to find specific parts in the string. */ @InternalApi public class SimpleParser { private static final char STATEMENT_DELIMITER = ';'; - private static final char SINGLE_QUOTE = '\''; - private static final char DOUBLE_QUOTE = '"'; + static final char SINGLE_QUOTE = '\''; + static final char DOUBLE_QUOTE = '"'; private static final char HYPHEN = '-'; private static final char SLASH = '/'; private static final char ASTERISK = '*'; - private static final char DOLLAR = '$'; + static final char DOLLAR = '$'; /** Name of table or index. */ static class TableOrIndexName { @@ -118,62 +118,6 @@ String getNameAndArrayBrackets() { } } - static class QuotedString { - final boolean escaped; - final char quote; - final String rawValue; - private String value; - - QuotedString(boolean escaped, char quote, String rawValue) { - this.escaped = escaped; - this.quote = quote; - this.rawValue = rawValue; - } - - String getValue() { - if (this.value == null) { - this.value = - this.escaped - ? unescapeQuotedStringValue(this.rawValue, this.quote) - : quotedStringValue(this.rawValue, this.quote); - } - return this.value; - } - - static String quotedStringValue(String quotedString, char quoteChar) { - if (quotedString.length() < 2 - || quotedString.charAt(0) != quoteChar - || quotedString.charAt(quotedString.length() - 1) != quoteChar) { - throw PGExceptionFactory.newPGException( - quotedString + " is not a valid string", SQLState.SyntaxError); - } - String doubleQuotes = String.valueOf(quoteChar) + quoteChar; - String singleQuote = String.valueOf(quoteChar); - return quotedString - .substring(1, quotedString.length() - 1) - .replace(doubleQuotes, singleQuote); - } - - static String unescapeQuotedStringValue(String quotedString, char quoteChar) { - if (quotedString.length() < 2 - || quotedString.charAt(0) != quoteChar - || quotedString.charAt(quotedString.length() - 1) != quoteChar) { - throw PGExceptionFactory.newPGException( - quotedString + " is not a valid string", SQLState.SyntaxError); - } - if (quotedString.startsWith(quoteChar + "\\x")) { - throw PGExceptionFactory.newPGException( - "PGAdapter does not support hexadecimal byte values in string literals", - SQLState.SyntaxError); - } - String result = - StringEscapeUtils.unescapeJava(quotedString.substring(1, quotedString.length() - 1)); - String doubleQuotes = String.valueOf(quoteChar) + quoteChar; - String singleQuote = String.valueOf(quoteChar); - return result.replace(doubleQuotes, singleQuote); - } - } - static String unquoteOrFoldIdentifier(String identifier) { if (Strings.isNullOrEmpty(identifier)) { return null; @@ -412,6 +356,9 @@ List expressionListToColumnNames(String name, List exp @Nonnull String readKeyword() { skipWhitespaces(); + if (!isValidStartOfKeyword(pos)) { + return ""; + } int startPos = pos; while (pos < sql.length() && !isValidEndOfKeyword(pos)) { pos++; @@ -517,11 +464,11 @@ String readIdentifierPart() { return sql.substring(start); } - private boolean isValidIdentifierFirstChar(char c) { + static boolean isValidIdentifierFirstChar(char c) { return Character.isLetter(c) || c == '_'; } - private boolean isValidIdentifierChar(char c) { + static boolean isValidIdentifierChar(char c) { return isValidIdentifierFirstChar(c) || Character.isDigit(c) || c == '$'; } @@ -537,6 +484,15 @@ boolean peek(boolean skipWhitespaceBefore, boolean requireWhitespaceAfter, Strin return internalEat(keyword, skipWhitespaceBefore, requireWhitespaceAfter, false); } + boolean peekCharsIgnoreCase(String characters) { + Preconditions.checkNotNull(characters); + Preconditions.checkArgument(characters.length() > 0); + if (characters.length() > sql.length() - pos) { + return false; + } + return sql.substring(pos, pos + characters.length()).equalsIgnoreCase(characters); + } + boolean eatKeyword(String... keywords) { return eat(true, true, keywords); } @@ -653,6 +609,13 @@ private boolean internalEat( return false; } + private boolean isValidStartOfKeyword(int index) { + if (sql.length() == index) { + return false; + } + return isValidIdentifierFirstChar(sql.charAt(index)); + } + private boolean isValidEndOfKeyword(int index) { if (sql.length() == index) { return true; @@ -664,11 +627,22 @@ boolean skipCommentsAndLiterals() { if (pos >= sql.length()) { return true; } - if ((sql.charAt(pos) == 'e' || sql.charAt(pos) == 'E') + if ((sql.charAt(pos) == 'e' + || sql.charAt(pos) == 'E' + || sql.charAt(pos) == 'b' + || sql.charAt(pos) == 'B' + || sql.charAt(pos) == 'x' + || sql.charAt(pos) == 'X') && sql.length() > (pos + 1) && sql.charAt(pos + 1) == '\'') { pos++; return skipQuotedString(true); + } else if (sql.length() > (pos + 2) + && (sql.charAt(pos) == 'U' || sql.charAt(pos) == 'u') + && sql.charAt(pos + 1) == '&' + && sql.charAt(pos + 2) == '\'') { + pos += 2; + return skipQuotedString(false); } else if (sql.charAt(pos) == SINGLE_QUOTE || sql.charAt(pos) == DOUBLE_QUOTE) { return skipQuotedString(false); } else if (sql.charAt(pos) == HYPHEN @@ -690,28 +664,11 @@ boolean skipCommentsAndLiterals() { } QuotedString readSingleQuotedString() { - return readQuotedString(SINGLE_QUOTE); + return LiteralParser.readSingleQuotedString(this); } QuotedString readDoubleQuotedString() { - return readQuotedString(DOUBLE_QUOTE); - } - - QuotedString readQuotedString(char quote) { - skipWhitespaces(); - if (pos >= sql.length()) { - throw PGExceptionFactory.newPGException("Unexpected end of expression", SQLState.SyntaxError); - } - boolean escaped = eatToken("e"); - if (sql.charAt(pos) != quote) { - throw PGExceptionFactory.newPGException( - "Invalid quote character: " + sql.charAt(pos), SQLState.SyntaxError); - } - int startPos = pos; - if (skipQuotedString(escaped)) { - return new QuotedString(escaped, quote, sql.substring(startPos, pos)); - } - throw PGExceptionFactory.newPGException("Missing end quote character", SQLState.SyntaxError); + return LiteralParser.readDoubleQuotedString(this); } boolean skipQuotedString(boolean escaped) { @@ -792,6 +749,7 @@ boolean skipMultiLineComment() { } String parseDollarQuotedTag() { + int originalPos = pos; // Look ahead to the next dollar sign (if any). Everything in between is the quote tag. StringBuilder tag = new StringBuilder(); while (pos < sql.length()) { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java index d39e1bd71..5a348efb8 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParser.java @@ -29,6 +29,7 @@ import com.google.cloud.spanner.pgadapter.parsers.NumericParser; import com.google.cloud.spanner.pgadapter.parsers.StringParser; import com.google.cloud.spanner.pgadapter.parsers.TimestampParser; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.common.base.Preconditions; import java.io.BufferedInputStream; import java.io.DataInputStream; @@ -236,13 +237,13 @@ public boolean hasColumnNames() { } @Override - public Value getValue(Type type, String columnName) { + public Value getValue(SessionState sessionState, Type type, String columnName) { // The binary copy format does not include any column name headers or any type information. throw new UnsupportedOperationException(); } @Override - public Value getValue(Type type, int columnIndex) { + public Value getValue(SessionState sessionState, Type type, int columnIndex) { Preconditions.checkArgument( columnIndex >= 0 && columnIndex < numColumns(), "columnIndex must be >= 0 && < numColumns"); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java index e6c252aa4..6289b5958 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CopyRecord.java @@ -17,6 +17,7 @@ import com.google.api.core.InternalApi; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Value; +import com.google.cloud.spanner.pgadapter.session.SessionState; /** * {@link CopyRecord} is a common interface for COPY data records that are produced by a parser for @@ -29,8 +30,8 @@ public interface CopyRecord { int numColumns(); /** - * Returns true if the copy record has column names. The {@link #getValue(Type, String)} method - * can only be used for records that have column names. + * Returns true if the copy record has column names. The {@link #getValue(SessionState, Type, + * String)} method can only be used for records that have column names. */ boolean hasColumnNames(); @@ -40,12 +41,12 @@ public interface CopyRecord { * where it is being inserted. This method can only be used with records that contains column * names. */ - Value getValue(Type type, String columnName); + Value getValue(SessionState sessionState, Type type, String columnName); /** * Returns the value of the given column as a Cloud Spanner {@link Value} of the given type. This * method is used by a COPY ... FROM ... operation to convert a value to the type of the column * where it is being inserted. This method is supported for all types of {@link CopyRecord}. */ - Value getValue(Type type, int columnIndex); + Value getValue(SessionState sessionState, Type type, int columnIndex); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java index df3fdf0a8..9c1d4a11a 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParser.java @@ -24,6 +24,7 @@ import com.google.cloud.spanner.Value; import com.google.cloud.spanner.pgadapter.parsers.BooleanParser; import com.google.cloud.spanner.pgadapter.parsers.TimestampParser; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.common.collect.Iterators; import java.io.IOException; import java.io.InputStreamReader; @@ -94,18 +95,21 @@ public boolean hasColumnNames() { } @Override - public Value getValue(Type type, String columnName) throws SpannerException { + public Value getValue(SessionState sessionState, Type type, String columnName) + throws SpannerException { String recordValue = record.get(columnName); - return getSpannerValue(type, recordValue); + return getSpannerValue(sessionState, type, recordValue); } @Override - public Value getValue(Type type, int columnIndex) throws SpannerException { + public Value getValue(SessionState sessionState, Type type, int columnIndex) + throws SpannerException { String recordValue = record.get(columnIndex); - return getSpannerValue(type, recordValue); + return getSpannerValue(sessionState, type, recordValue); } - static Value getSpannerValue(Type type, String recordValue) throws SpannerException { + static Value getSpannerValue(SessionState sessionState, Type type, String recordValue) + throws SpannerException { try { switch (type.getCode()) { case STRING: @@ -134,7 +138,7 @@ static Value getSpannerValue(Type type, String recordValue) throws SpannerExcept return Value.date(recordValue == null ? null : Date.parseDate(recordValue)); case TIMESTAMP: Timestamp timestamp = - recordValue == null ? null : TimestampParser.toTimestamp(recordValue); + recordValue == null ? null : TimestampParser.toTimestamp(recordValue, sessionState); return Value.timestamp(timestamp); default: SpannerException spannerException = diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java index 2dc290fe3..88ce02e0f 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java @@ -524,8 +524,8 @@ Mutation buildMutation(CopyRecord record) { Type columnType = this.tableColumns.get(columnName); Value value = record.hasColumnNames() - ? record.getValue(columnType, columnName) - : record.getValue(columnType, index); + ? record.getValue(copySettings.getSessionState(), columnType, columnName) + : record.getValue(copySettings.getSessionState(), columnType, index); builder.set(columnName).to(value); index++; } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java index 15373904b..cc4787c0f 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/CopyOutMockServerTest.java @@ -20,6 +20,8 @@ import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.google.cloud.ByteArray; import com.google.cloud.Date; @@ -33,6 +35,7 @@ import com.google.cloud.spanner.Value; import com.google.cloud.spanner.connection.RandomResultSetGenerator; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.cloud.spanner.pgadapter.statements.CopyStatement.Format; import com.google.cloud.spanner.pgadapter.utils.CopyInParser; import com.google.cloud.spanner.pgadapter.utils.CopyRecord; @@ -58,6 +61,7 @@ import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.SQLException; +import java.time.ZoneId; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -469,18 +473,22 @@ public void testCopyOutBinaryPsql() throws Exception { CopyRecord record = iterator.next(); assertFalse(iterator.hasNext()); - assertEquals(Value.int64(1L), record.getValue(Type.int64(), 0)); - assertEquals(Value.bool(true), record.getValue(Type.bool(), 1)); - assertEquals(Value.bytes(ByteArray.copyFrom("test")), record.getValue(Type.bytes(), 2)); - assertEquals(Value.float64(3.14), record.getValue(Type.float64(), 3)); - assertEquals(Value.int64(100L), record.getValue(Type.int64(), 4)); - assertEquals(Value.pgNumeric("6.626"), record.getValue(Type.pgNumeric(), 5)); + SessionState sessionState = mock(SessionState.class); + when(sessionState.getTimezone()).thenReturn(ZoneId.systemDefault()); + assertEquals(Value.int64(1L), record.getValue(sessionState, Type.int64(), 0)); + assertEquals(Value.bool(true), record.getValue(sessionState, Type.bool(), 1)); + assertEquals( + Value.bytes(ByteArray.copyFrom("test")), record.getValue(sessionState, Type.bytes(), 2)); + assertEquals(Value.float64(3.14), record.getValue(sessionState, Type.float64(), 3)); + assertEquals(Value.int64(100L), record.getValue(sessionState, Type.int64(), 4)); + assertEquals(Value.pgNumeric("6.626"), record.getValue(sessionState, Type.pgNumeric(), 5)); // Note: The binary format truncates timestamptz value to microsecond precision. assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-02-16T13:18:02.123456000Z")), - record.getValue(Type.timestamp(), 6)); - assertEquals(Value.date(Date.parseDate("2022-03-29")), record.getValue(Type.date(), 7)); - assertEquals(Value.string("test"), record.getValue(Type.string(), 8)); + record.getValue(sessionState, Type.timestamp(), 6)); + assertEquals( + Value.date(Date.parseDate("2022-03-29")), record.getValue(sessionState, Type.date(), 7)); + assertEquals(Value.string("test"), record.getValue(sessionState, Type.string(), 8)); } @Test @@ -521,10 +529,11 @@ public void testCopyOutNullsBinaryPsql() throws Exception { CopyRecord record = iterator.next(); assertFalse(iterator.hasNext()); + SessionState sessionState = mock(SessionState.class); for (int col = 0; col < record.numColumns(); col++) { // Note: Null values in a COPY BINARY stream are untyped, so it does not matter what type we // specify when getting the value. - assertTrue(record.getValue(Type.string(), col).isNull()); + assertTrue(record.getValue(sessionState, Type.string(), col).isNull()); } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParserTest.java index 0f3492a9b..ef1710208 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/parsers/TimestampParserTest.java @@ -110,4 +110,14 @@ public void testStringParse() { new TimestampParser( "foo".getBytes(StandardCharsets.UTF_8), FormatCode.TEXT, sessionState)); } + + @Test + public void testTextToTimestamp() { + SessionState sessionState = mock(SessionState.class); + when(sessionState.getTimezone()).thenReturn(ZoneId.of("-09:00")); + + assertEquals( + Timestamp.parseTimestamp("2022-10-09T19:09:18Z"), + TimestampParser.toTimestamp("2022-10-09 10:09:18", sessionState)); + } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatementTest.java index d07bd8ea2..dc86558ae 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/ExecuteStatementTest.java @@ -29,7 +29,9 @@ import com.google.cloud.spanner.pgadapter.error.PGException; import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; +import com.google.cloud.spanner.pgadapter.session.SessionState; import java.nio.charset.StandardCharsets; +import java.time.ZoneId; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -40,7 +42,15 @@ public class ExecuteStatementTest { @Test public void testGetStatementType() { ConnectionHandler connectionHandler = mock(ConnectionHandler.class); + ExtendedQueryProtocolHandler extendedQueryProtocolHandler = + mock(ExtendedQueryProtocolHandler.class); + BackendConnection backendConnection = mock(BackendConnection.class); + SessionState sessionState = mock(SessionState.class); when(connectionHandler.getConnectionMetadata()).thenReturn(mock(ConnectionMetadata.class)); + when(connectionHandler.getExtendedQueryProtocolHandler()) + .thenReturn(extendedQueryProtocolHandler); + when(extendedQueryProtocolHandler.getBackendConnection()).thenReturn(backendConnection); + when(backendConnection.getSessionState()).thenReturn(sessionState); assertEquals( StatementType.CLIENT_SIDE, new ExecuteStatement( @@ -55,32 +65,52 @@ public void testGetStatementType() { @Test public void testParse() { - assertEquals("foo", parse("execute foo").name); - assertEquals("foo", parse("execute FOO").name); - assertEquals("foo", parse("execute\tfoo").name); - assertEquals("foo", parse("execute\nfoo").name); - assertEquals("foo", parse("execute/*comment*/foo").name); - assertEquals("foo", parse("execute \"foo\"").name); - assertEquals("Foo", parse("execute \"Foo\"").name); - assertEquals("foo", parse("execute foo (1)").name); - assertEquals("foo", parse("execute foo (1, 'test')").name); + SessionState sessionState = mock(SessionState.class); + when(sessionState.getTimezone()).thenReturn(ZoneId.of("+01:00")); + + assertEquals("foo", parse("execute foo", sessionState).name); + assertEquals("foo", parse("execute FOO", sessionState).name); + assertEquals("foo", parse("execute\tfoo", sessionState).name); + assertEquals("foo", parse("execute\nfoo", sessionState).name); + assertEquals("foo", parse("execute/*comment*/foo", sessionState).name); + assertEquals("foo", parse("execute \"foo\"", sessionState).name); + assertEquals("Foo", parse("execute \"Foo\"", sessionState).name); + assertEquals("foo", parse("execute foo (1)", sessionState).name); + assertEquals("foo", parse("execute foo (1, 'test')", sessionState).name); assertArrayEquals( - new byte[][] {param("1"), param("test")}, parse("execute foo (1, 'test')").parameters); + new byte[][] {param("1"), param("test")}, + parse("execute foo (1, 'test')", sessionState).parameters); assertArrayEquals( new byte[][] {param("3.14"), param("\\x55aa")}, - parse("execute foo (3.14, '\\x55aa')").parameters); + parse("execute foo (3.14, '\\x55aa')", sessionState).parameters); assertArrayEquals( new byte[][] {null, param("2000-01-01")}, - parse("execute foo (null, '2000-01-01')").parameters); + parse("execute foo (null, '2000-01-01')", sessionState).parameters); + assertArrayEquals( + new byte[][] {param("1"), param("test")}, + parse("execute foo (cast(1 as int), varchar 'test')", sessionState).parameters); + assertArrayEquals( + new byte[][] {param("1"), param("test")}, + parse("execute foo (1::bigint, varchar 'test')", sessionState).parameters); + assertArrayEquals( + new byte[][] {param("1"), param("test")}, + parse("execute foo (1.0::bigint, e'test')", sessionState).parameters); + assertArrayEquals( + new byte[][] {param("1"), param("2022-10-09")}, + parse("execute foo (1.0::int, '2022-10-09'::date)", sessionState).parameters); + assertArrayEquals( + new byte[][] {param("1"), param("2022-10-09 10:09:18+01")}, + parse("execute foo (1.0::int4, '2022-10-09 10:09:18'::timestamptz)", sessionState) + .parameters); - assertThrows(PGException.class, () -> parse("foo")); - assertThrows(PGException.class, () -> parse("execute")); - assertThrows(PGException.class, () -> parse("execute foo bar")); - assertThrows(PGException.class, () -> parse("execute foo.bar")); - assertThrows(PGException.class, () -> parse("execute foo ()")); - assertThrows(PGException.class, () -> parse("execute foo (1) bar")); - assertThrows(PGException.class, () -> parse("execute foo (1")); + assertThrows(PGException.class, () -> parse("foo", sessionState)); + assertThrows(PGException.class, () -> parse("execute", sessionState)); + assertThrows(PGException.class, () -> parse("execute foo bar", sessionState)); + assertThrows(PGException.class, () -> parse("execute foo.bar", sessionState)); + assertThrows(PGException.class, () -> parse("execute foo ()", sessionState)); + assertThrows(PGException.class, () -> parse("execute foo (1) bar", sessionState)); + assertThrows(PGException.class, () -> parse("execute foo (1", sessionState)); } static byte[] param(String value) { diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/LiteralParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/LiteralParserTest.java new file mode 100644 index 000000000..214b1ef63 --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/LiteralParserTest.java @@ -0,0 +1,186 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.spanner.pgadapter.statements; + +import static com.google.cloud.spanner.pgadapter.statements.LiteralParser.dataTypeNameToOid; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import com.google.cloud.spanner.pgadapter.error.PGException; +import com.google.cloud.spanner.pgadapter.statements.LiteralParser.Literal; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.postgresql.core.Oid; + +@RunWith(JUnit4.class) +public class LiteralParserTest { + + @Test + public void testDataTypeNameToOid() { + assertEquals(Oid.UNSPECIFIED, dataTypeNameToOid("unknown")); + assertEquals(Oid.INT8, dataTypeNameToOid("bigint")); + assertEquals(Oid.INT8_ARRAY, dataTypeNameToOid("bigint[]")); + assertEquals(Oid.INT8, dataTypeNameToOid("int8")); + assertEquals(Oid.INT8_ARRAY, dataTypeNameToOid("int8[]")); + assertEquals(Oid.INT4, dataTypeNameToOid("int")); + assertEquals(Oid.INT4_ARRAY, dataTypeNameToOid("int[]")); + assertEquals(Oid.INT4, dataTypeNameToOid("int4")); + assertEquals(Oid.INT4_ARRAY, dataTypeNameToOid("int4[]")); + assertEquals(Oid.INT4, dataTypeNameToOid("integer")); + assertEquals(Oid.INT4_ARRAY, dataTypeNameToOid("integer[]")); + assertEquals(Oid.BOOL, dataTypeNameToOid("boolean")); + assertEquals(Oid.BOOL_ARRAY, dataTypeNameToOid("boolean[]")); + assertEquals(Oid.BOOL, dataTypeNameToOid("bool")); + assertEquals(Oid.BOOL_ARRAY, dataTypeNameToOid("bool[]")); + assertEquals(Oid.DATE, dataTypeNameToOid("date")); + assertEquals(Oid.DATE_ARRAY, dataTypeNameToOid("date[]")); + assertEquals(Oid.FLOAT8, dataTypeNameToOid("double precision")); + assertEquals(Oid.FLOAT8_ARRAY, dataTypeNameToOid("double precision[]")); + assertEquals(Oid.FLOAT8, dataTypeNameToOid("float8")); + assertEquals(Oid.FLOAT8_ARRAY, dataTypeNameToOid("float8[]")); + assertEquals(Oid.JSONB, dataTypeNameToOid("jsonb")); + assertEquals(Oid.JSONB_ARRAY, dataTypeNameToOid("jsonb[]")); + assertEquals(Oid.NUMERIC, dataTypeNameToOid("numeric")); + assertEquals(Oid.NUMERIC_ARRAY, dataTypeNameToOid("numeric[]")); + assertEquals(Oid.NUMERIC, dataTypeNameToOid("decimal")); + assertEquals(Oid.NUMERIC_ARRAY, dataTypeNameToOid("decimal[]")); + assertEquals(Oid.NUMERIC, dataTypeNameToOid("numeric(1,1)")); + assertEquals(Oid.NUMERIC_ARRAY, dataTypeNameToOid("numeric(2, 1)[]")); + + assertEquals(Oid.VARCHAR, dataTypeNameToOid("character varying")); + assertEquals(Oid.VARCHAR, dataTypeNameToOid("varchar")); + assertEquals(Oid.VARCHAR, dataTypeNameToOid("character varying(100)")); + assertEquals(Oid.VARCHAR, dataTypeNameToOid("varchar(100)")); + assertEquals(Oid.VARCHAR, dataTypeNameToOid("character varying (100)")); + assertEquals(Oid.VARCHAR, dataTypeNameToOid("varchar (100)")); + assertEquals(Oid.VARCHAR, dataTypeNameToOid("character varying ( 100 ) \t")); + assertEquals(Oid.VARCHAR, dataTypeNameToOid("varchar\t(100) \n")); + + assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("character varying[]")); + assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("varchar[]")); + assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("character varying(100)[]")); + assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("varchar(100)[]")); + assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("character varying (100) []")); + assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("varchar (100) []")); + assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("character varying ( 100 ) \t[\n]")); + assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("varchar\t(100) \n[ ]")); + + assertThrows(PGException.class, () -> dataTypeNameToOid("invalid_type")); + assertThrows(PGException.class, () -> dataTypeNameToOid("varchar(100")); + assertThrows(PGException.class, () -> dataTypeNameToOid("bigint[")); + assertThrows(PGException.class, () -> dataTypeNameToOid("varchar(bar)")); + assertThrows(PGException.class, () -> dataTypeNameToOid("numeric(10, bar)")); + } + + @Test + public void testReadNumericLiteralValue() { + assertEquals("100", readNumericLiteral("100")); + assertEquals("100.1", readNumericLiteral("100.1")); + assertEquals("100", readNumericLiteral("100 ")); + assertEquals("100.1", readNumericLiteral("100.1 ")); + assertEquals("100", readNumericLiteral("100,")); + assertEquals("100.1", readNumericLiteral("100.1,")); + // Note: We don't try to validate the numeric literal here, we only want to detect the end of + // the literal. + assertEquals("100a", readNumericLiteral("100a")); + assertEquals("100.1a", readNumericLiteral("100.1a")); + assertEquals("100", readNumericLiteral("100+100")); + assertEquals("100.1", readNumericLiteral("100.1+100")); + assertEquals("100e100", readNumericLiteral("100e100")); + assertEquals("100.1e100", readNumericLiteral("100.1e100")); + assertEquals("100e+100", readNumericLiteral("100e+100")); + assertEquals("100.1e+100", readNumericLiteral("100.1e+100")); + assertEquals("100e-100", readNumericLiteral("100e-100")); + assertEquals("100.1e-100", readNumericLiteral("100.1e-100")); + + assertEquals("+100", readNumericLiteral("+100")); + assertEquals("-100", readNumericLiteral("-100")); + assertEquals("+", readNumericLiteral("+-100")); + assertEquals("-", readNumericLiteral("-+100")); + assertEquals("", readNumericLiteral("*100")); + assertEquals("", readNumericLiteral("/100")); + } + + @Test + public void testReadLiteralValue() { + assertEquals("100", readLiteralValue("100", false)); + assertEquals("100", readLiteralValue("100 200", false)); + assertEquals("test", readLiteralValue("'test'", false)); + assertEquals("test", readLiteralValue("e'test'", false)); + assertEquals("test", readLiteralValue("E'test'", false)); + assertEquals("test", readLiteralValue("'test'", true)); + assertEquals("test", readLiteralValue("e'test'", true)); + assertEquals("test", readLiteralValue("E'test'", true)); + assertEquals("test", readLiteralValue("'test', 100", false)); + assertEquals("test", readLiteralValue("'test'100", false)); + assertEquals("test", readLiteralValue("'test' 'test'", false)); + assertEquals("test", readLiteralValue("$$test$$", false)); + assertEquals("test", readLiteralValue("$tag$test$tag$", false)); + assertEquals("test", readLiteralValue("$$test$$", true)); + assertEquals("test", readLiteralValue("$tag$test$tag$", true)); + } + + @Test + public void testReadConstantLiteralExpression() { + assertEquals(Literal.of("100"), readConstantLiteralExpression("100")); + assertEquals(Literal.of("100"), readConstantLiteralExpression("100 200")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("'test'")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("e'test'")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("E'test'")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("'test'")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("e'test'")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("E'test'")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("'test', 100")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("'test'100")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("'test' 'test'")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("$$test$$")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("$tag$test$tag$")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("$$test$$")); + assertEquals(Literal.of("test"), readConstantLiteralExpression("$tag$test$tag$")); + + assertEquals( + Literal.of("100", Oid.VARCHAR), readConstantLiteralExpression("cast(100 as varchar)")); + assertEquals( + Literal.of("100", Oid.VARCHAR), readConstantLiteralExpression("cast(100 as varchar(10))")); + assertEquals(Literal.of("100", Oid.VARCHAR), readConstantLiteralExpression("varchar '100'")); + assertEquals( + Literal.of("2022-12-12T20:09:00+01:00", Oid.TIMESTAMPTZ), + readConstantLiteralExpression("timestamptz '2022-12-12T20:09:00+01:00'")); + assertEquals( + Literal.of("test", Oid.VARCHAR), readConstantLiteralExpression("varchar $$test$$")); + assertEquals( + Literal.of("test", Oid.VARCHAR), + readConstantLiteralExpression("cast($tag$test$tag$ as varchar)")); + assertEquals( + Literal.of("2022-12-12", Oid.DATE), + readConstantLiteralExpression("$tag$2022-12-12$tag$::date")); + } + + static String readNumericLiteral(String input) { + LiteralParser literalParser = new LiteralParser(new SimpleParser(input)); + return literalParser.readNumericLiteralValue(); + } + + static String readLiteralValue(String input, boolean mustBeQuoted) { + LiteralParser literalParser = new LiteralParser(new SimpleParser(input)); + return literalParser.readLiteralValue(mustBeQuoted); + } + + static Literal readConstantLiteralExpression(String input) { + LiteralParser literalParser = new LiteralParser(new SimpleParser(input)); + return literalParser.readConstantLiteralExpression(); + } +} diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatementTest.java index ee0f0124a..8ad3378f8 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/PrepareStatementTest.java @@ -14,7 +14,6 @@ package com.google.cloud.spanner.pgadapter.statements; -import static com.google.cloud.spanner.pgadapter.statements.PrepareStatement.dataTypeNameToOid; import static com.google.cloud.spanner.pgadapter.statements.PrepareStatement.parse; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -30,63 +29,6 @@ @RunWith(JUnit4.class) public class PrepareStatementTest { - @Test - public void testDataTypeNameToOid() { - assertEquals(Oid.UNSPECIFIED, dataTypeNameToOid("unknown")); - assertEquals(Oid.INT8, dataTypeNameToOid("bigint")); - assertEquals(Oid.INT8_ARRAY, dataTypeNameToOid("bigint[]")); - assertEquals(Oid.INT8, dataTypeNameToOid("int8")); - assertEquals(Oid.INT8_ARRAY, dataTypeNameToOid("int8[]")); - assertEquals(Oid.INT4, dataTypeNameToOid("int")); - assertEquals(Oid.INT4_ARRAY, dataTypeNameToOid("int[]")); - assertEquals(Oid.INT4, dataTypeNameToOid("int4")); - assertEquals(Oid.INT4_ARRAY, dataTypeNameToOid("int4[]")); - assertEquals(Oid.INT4, dataTypeNameToOid("integer")); - assertEquals(Oid.INT4_ARRAY, dataTypeNameToOid("integer[]")); - assertEquals(Oid.BOOL, dataTypeNameToOid("boolean")); - assertEquals(Oid.BOOL_ARRAY, dataTypeNameToOid("boolean[]")); - assertEquals(Oid.BOOL, dataTypeNameToOid("bool")); - assertEquals(Oid.BOOL_ARRAY, dataTypeNameToOid("bool[]")); - assertEquals(Oid.DATE, dataTypeNameToOid("date")); - assertEquals(Oid.DATE_ARRAY, dataTypeNameToOid("date[]")); - assertEquals(Oid.FLOAT8, dataTypeNameToOid("double precision")); - assertEquals(Oid.FLOAT8_ARRAY, dataTypeNameToOid("double precision[]")); - assertEquals(Oid.FLOAT8, dataTypeNameToOid("float8")); - assertEquals(Oid.FLOAT8_ARRAY, dataTypeNameToOid("float8[]")); - assertEquals(Oid.JSONB, dataTypeNameToOid("jsonb")); - assertEquals(Oid.JSONB_ARRAY, dataTypeNameToOid("jsonb[]")); - assertEquals(Oid.NUMERIC, dataTypeNameToOid("numeric")); - assertEquals(Oid.NUMERIC_ARRAY, dataTypeNameToOid("numeric[]")); - assertEquals(Oid.NUMERIC, dataTypeNameToOid("decimal")); - assertEquals(Oid.NUMERIC_ARRAY, dataTypeNameToOid("decimal[]")); - assertEquals(Oid.NUMERIC, dataTypeNameToOid("numeric(1,1)")); - assertEquals(Oid.NUMERIC_ARRAY, dataTypeNameToOid("numeric(2, 1)[]")); - - assertEquals(Oid.VARCHAR, dataTypeNameToOid("character varying")); - assertEquals(Oid.VARCHAR, dataTypeNameToOid("varchar")); - assertEquals(Oid.VARCHAR, dataTypeNameToOid("character varying(100)")); - assertEquals(Oid.VARCHAR, dataTypeNameToOid("varchar(100)")); - assertEquals(Oid.VARCHAR, dataTypeNameToOid("character varying (100)")); - assertEquals(Oid.VARCHAR, dataTypeNameToOid("varchar (100)")); - assertEquals(Oid.VARCHAR, dataTypeNameToOid("character varying ( 100 ) \t")); - assertEquals(Oid.VARCHAR, dataTypeNameToOid("varchar\t(100) \n")); - - assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("character varying[]")); - assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("varchar[]")); - assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("character varying(100)[]")); - assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("varchar(100)[]")); - assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("character varying (100) []")); - assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("varchar (100) []")); - assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("character varying ( 100 ) \t[\n]")); - assertEquals(Oid.VARCHAR_ARRAY, dataTypeNameToOid("varchar\t(100) \n[ ]")); - - assertThrows(PGException.class, () -> dataTypeNameToOid("invalid_type")); - assertThrows(PGException.class, () -> dataTypeNameToOid("varchar(100")); - assertThrows(PGException.class, () -> dataTypeNameToOid("bigint[")); - assertThrows(PGException.class, () -> dataTypeNameToOid("varchar(bar)")); - assertThrows(PGException.class, () -> dataTypeNameToOid("numeric(10, bar)")); - } - @Test public void testParse() { ParsedPreparedStatement statement = parse("prepare foo as select 1"); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/statements/SimpleParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/statements/SimpleParserTest.java index cd04d756e..eed7710c3 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/statements/SimpleParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/statements/SimpleParserTest.java @@ -14,7 +14,7 @@ package com.google.cloud.spanner.pgadapter.statements; -import static com.google.cloud.spanner.pgadapter.statements.SimpleParser.QuotedString.unescapeQuotedStringValue; +import static com.google.cloud.spanner.pgadapter.statements.LiteralParser.QuotedString.unescapeQuotedStringValue; import static com.google.cloud.spanner.pgadapter.statements.SimpleParser.parseCommand; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -24,7 +24,7 @@ import static org.junit.Assert.assertTrue; import com.google.cloud.spanner.pgadapter.error.PGException; -import com.google.cloud.spanner.pgadapter.statements.SimpleParser.QuotedString; +import com.google.cloud.spanner.pgadapter.statements.LiteralParser.QuotedString; import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -426,16 +426,16 @@ public void testParseCommand() { @Test public void testQuotedString() { - assertEquals("test", new SimpleParser("'test'").readQuotedString('\'').getValue()); - assertEquals("test", new SimpleParser("e'test'").readQuotedString('\'').getValue()); + assertEquals("test", new SimpleParser("'test'").readSingleQuotedString().getValue()); + assertEquals("test", new SimpleParser("e'test'").readSingleQuotedString().getValue()); PGException exception = assertThrows( - PGException.class, () -> new SimpleParser("test").readQuotedString('\'').getValue()); + PGException.class, () -> new SimpleParser("test").readSingleQuotedString().getValue()); assertEquals("Invalid quote character: t", exception.getMessage()); exception = assertThrows( PGException.class, - () -> new SimpleParser("e\"test\"").readQuotedString('\'').getValue()); + () -> new SimpleParser("e\"test\"").readSingleQuotedString().getValue()); assertEquals("Invalid quote character: \"", exception.getMessage()); exception = @@ -445,9 +445,11 @@ public void testQuotedString() { assertThrows(PGException.class, () -> new QuotedString(true, '\'', "test").getValue()); assertEquals("test is not a valid string", exception.getMessage()); - exception = assertThrows(PGException.class, () -> new SimpleParser("'").readQuotedString('\'')); + exception = + assertThrows(PGException.class, () -> new SimpleParser("'").readSingleQuotedString()); assertEquals("Missing end quote character", exception.getMessage()); - exception = assertThrows(PGException.class, () -> new SimpleParser(" ").readQuotedString('\'')); + exception = + assertThrows(PGException.class, () -> new SimpleParser(" ").readSingleQuotedString()); assertEquals("Unexpected end of expression", exception.getMessage()); } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParserTest.java index 3bb7ff7d5..7a3ab2bf6 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/utils/BinaryCopyParserTest.java @@ -20,11 +20,13 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Value; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.cloud.spanner.pgadapter.utils.BinaryCopyParser.BinaryField; import com.google.cloud.spanner.pgadapter.utils.BinaryCopyParser.BinaryRecord; import java.io.DataOutputStream; @@ -215,6 +217,7 @@ public void testIteratorNext_WithNoMoreElements() throws IOException { @Test public void testIteratorNext_NullField() throws IOException { + SessionState sessionState = mock(SessionState.class); PipedOutputStream pipedOutputStream = new PipedOutputStream(); BinaryCopyParser parser = new BinaryCopyParser(new PipedInputStream(pipedOutputStream, 256)); @@ -233,12 +236,13 @@ public void testIteratorNext_NullField() throws IOException { Iterator iterator = parser.iterator(); assertTrue(iterator.hasNext()); CopyRecord record = iterator.next(); - assertTrue(record.getValue(Type.int64(), 0).isNull()); + assertTrue(record.getValue(sessionState, Type.int64(), 0).isNull()); assertFalse(iterator.hasNext()); } @Test public void testIteratorNext_GetValue() throws IOException { + SessionState sessionState = mock(SessionState.class); PipedOutputStream pipedOutputStream = new PipedOutputStream(); BinaryCopyParser parser = new BinaryCopyParser(new PipedInputStream(pipedOutputStream, 256)); @@ -258,7 +262,7 @@ public void testIteratorNext_GetValue() throws IOException { Iterator iterator = parser.iterator(); assertTrue(iterator.hasNext()); CopyRecord record = iterator.next(); - assertEquals(Value.int64(100L), record.getValue(Type.int64(), 0)); + assertEquals(Value.int64(100L), record.getValue(sessionState, Type.int64(), 0)); assertFalse(iterator.hasNext()); } @@ -303,6 +307,7 @@ public void testIteratorNext_EndOfFile() throws IOException { @Test public void testIteratorNext_WithOid() throws IOException { + SessionState sessionState = mock(SessionState.class); PipedOutputStream pipedOutputStream = new PipedOutputStream(); BinaryCopyParser parser = new BinaryCopyParser(new PipedInputStream(pipedOutputStream, 256)); @@ -326,8 +331,8 @@ public void testIteratorNext_WithOid() throws IOException { Iterator iterator = parser.iterator(); assertTrue(iterator.hasNext()); CopyRecord record = iterator.next(); - assertEquals(Value.int64(100L), record.getValue(Type.int64(), 0)); - assertEquals(Value.string("test"), record.getValue(Type.string(), 1)); + assertEquals(Value.int64(100L), record.getValue(sessionState, Type.int64(), 0)); + assertEquals(Value.string("test"), record.getValue(sessionState, Type.string(), 1)); assertFalse(iterator.hasNext()); } @@ -354,11 +359,14 @@ public void testIteratorNext_InvalidOidLength() throws IOException { @Test public void testBinaryRecord() { + SessionState sessionState = mock(SessionState.class); BinaryRecord record = new BinaryRecord(new BinaryField[1]); assertFalse(record.hasColumnNames()); assertThrows( - UnsupportedOperationException.class, () -> record.getValue(Type.string(), "column_name")); - assertThrows(IllegalArgumentException.class, () -> record.getValue(Type.string(), 10)); - assertThrows(SpannerException.class, () -> record.getValue(Type.numeric(), 0)); + UnsupportedOperationException.class, + () -> record.getValue(sessionState, Type.string(), "column_name")); + assertThrows( + IllegalArgumentException.class, () -> record.getValue(sessionState, Type.string(), 10)); + assertThrows(SpannerException.class, () -> record.getValue(sessionState, Type.numeric(), 0)); } } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParserTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParserTest.java index e836c5e23..e5f467390 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParserTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/utils/CsvCopyParserTest.java @@ -17,6 +17,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; import com.google.cloud.ByteArray; import com.google.cloud.Date; @@ -25,6 +26,7 @@ import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.Value; +import com.google.cloud.spanner.pgadapter.session.SessionState; import com.google.cloud.spanner.pgadapter.utils.CsvCopyParser.CsvCopyRecord; import java.io.DataOutputStream; import java.io.IOException; @@ -77,157 +79,224 @@ public void testCanCreateIteratorWithHeader() throws IOException { @Test public void testGetSpannerValueBool() { - assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(Type.bool(), "t")); - assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(Type.bool(), "tr")); - assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(Type.bool(), "tru")); - assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(Type.bool(), "true")); - assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(Type.bool(), "1")); - assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(Type.bool(), "on")); - assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(Type.bool(), "y")); - assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(Type.bool(), "ye")); - assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(Type.bool(), "yes")); - - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "f")); - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "fa")); - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "fal")); - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "fals")); - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "false")); - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "0")); - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "off")); - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "of")); - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "n")); - assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(Type.bool(), "no")); - - assertEquals(Value.bool(null), CsvCopyRecord.getSpannerValue(Type.bool(), null)); + SessionState sessionState = mock(SessionState.class); + + assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "t")); + assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "tr")); + assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "tru")); + assertEquals( + Value.bool(true), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "true")); + assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "1")); + assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "on")); + assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "y")); + assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "ye")); + assertEquals(Value.bool(true), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "yes")); + + assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "f")); + assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "fa")); + assertEquals( + Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "fal")); + assertEquals( + Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "fals")); + assertEquals( + Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "false")); + assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "0")); + assertEquals( + Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "off")); + assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "of")); + assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "n")); + assertEquals(Value.bool(false), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "no")); + + assertEquals(Value.bool(null), CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), null)); } @Test public void testGetSpannerValueBytes() { + SessionState sessionState = mock(SessionState.class); + assertEquals( Value.bytes(ByteArray.copyFrom("test")), - CsvCopyRecord.getSpannerValue(Type.bytes(), "\\x74657374")); + CsvCopyRecord.getSpannerValue(sessionState, Type.bytes(), "\\x74657374")); + assertEquals( + Value.bytes(ByteArray.copyFrom("")), + CsvCopyRecord.getSpannerValue(sessionState, Type.bytes(), "\\x")); assertEquals( - Value.bytes(ByteArray.copyFrom("")), CsvCopyRecord.getSpannerValue(Type.bytes(), "\\x")); - assertEquals(Value.bytes(null), CsvCopyRecord.getSpannerValue(Type.bytes(), null)); + Value.bytes(null), CsvCopyRecord.getSpannerValue(sessionState, Type.bytes(), null)); } @Test public void testGetSpannerValueInt64() { - assertEquals(Value.int64(-1L), CsvCopyRecord.getSpannerValue(Type.int64(), "-1")); - assertEquals(Value.int64(1L), CsvCopyRecord.getSpannerValue(Type.int64(), "1")); - assertEquals(Value.int64(0L), CsvCopyRecord.getSpannerValue(Type.int64(), "0")); + SessionState sessionState = mock(SessionState.class); + + assertEquals(Value.int64(-1L), CsvCopyRecord.getSpannerValue(sessionState, Type.int64(), "-1")); + assertEquals(Value.int64(1L), CsvCopyRecord.getSpannerValue(sessionState, Type.int64(), "1")); + assertEquals(Value.int64(0L), CsvCopyRecord.getSpannerValue(sessionState, Type.int64(), "0")); assertEquals( Value.int64(Long.MAX_VALUE), - CsvCopyRecord.getSpannerValue(Type.int64(), "9223372036854775807")); + CsvCopyRecord.getSpannerValue(sessionState, Type.int64(), "9223372036854775807")); assertEquals( Value.int64(Long.MIN_VALUE), - CsvCopyRecord.getSpannerValue(Type.int64(), "-9223372036854775808")); - assertEquals(Value.int64(null), CsvCopyRecord.getSpannerValue(Type.int64(), null)); + CsvCopyRecord.getSpannerValue(sessionState, Type.int64(), "-9223372036854775808")); + assertEquals( + Value.int64(null), CsvCopyRecord.getSpannerValue(sessionState, Type.int64(), null)); } @Test public void testGetSpannerValueFloat64() { - assertEquals(Value.float64(-1.0D), CsvCopyRecord.getSpannerValue(Type.float64(), "-1.0")); - assertEquals(Value.float64(0.0D), CsvCopyRecord.getSpannerValue(Type.float64(), "0.0")); - assertEquals(Value.float64(1.0D), CsvCopyRecord.getSpannerValue(Type.float64(), "1.0")); - assertEquals(Value.float64(null), CsvCopyRecord.getSpannerValue(Type.float64(), null)); + SessionState sessionState = mock(SessionState.class); + + assertEquals( + Value.float64(-1.0D), CsvCopyRecord.getSpannerValue(sessionState, Type.float64(), "-1.0")); + assertEquals( + Value.float64(0.0D), CsvCopyRecord.getSpannerValue(sessionState, Type.float64(), "0.0")); + assertEquals( + Value.float64(1.0D), CsvCopyRecord.getSpannerValue(sessionState, Type.float64(), "1.0")); + assertEquals( + Value.float64(null), CsvCopyRecord.getSpannerValue(sessionState, Type.float64(), null)); } @Test public void testGetSpannerValueNumeric() { - assertEquals(Value.pgNumeric("-1.0"), CsvCopyRecord.getSpannerValue(Type.pgNumeric(), "-1.0")); - assertEquals(Value.pgNumeric("0.0"), CsvCopyRecord.getSpannerValue(Type.pgNumeric(), "0.0")); - assertEquals(Value.pgNumeric("1.0"), CsvCopyRecord.getSpannerValue(Type.pgNumeric(), "1.0")); - assertEquals(Value.pgNumeric(null), CsvCopyRecord.getSpannerValue(Type.pgNumeric(), null)); + SessionState sessionState = mock(SessionState.class); + + assertEquals( + Value.pgNumeric("-1.0"), + CsvCopyRecord.getSpannerValue(sessionState, Type.pgNumeric(), "-1.0")); + assertEquals( + Value.pgNumeric("0.0"), + CsvCopyRecord.getSpannerValue(sessionState, Type.pgNumeric(), "0.0")); + assertEquals( + Value.pgNumeric("1.0"), + CsvCopyRecord.getSpannerValue(sessionState, Type.pgNumeric(), "1.0")); + assertEquals( + Value.pgNumeric(null), CsvCopyRecord.getSpannerValue(sessionState, Type.pgNumeric(), null)); } @Test public void testGetSpannerValueString() { - assertEquals(Value.string("test"), CsvCopyRecord.getSpannerValue(Type.string(), "test")); - assertEquals(Value.string(""), CsvCopyRecord.getSpannerValue(Type.string(), "")); - assertEquals(Value.string(null), CsvCopyRecord.getSpannerValue(Type.string(), null)); + SessionState sessionState = mock(SessionState.class); + + assertEquals( + Value.string("test"), CsvCopyRecord.getSpannerValue(sessionState, Type.string(), "test")); + assertEquals(Value.string(""), CsvCopyRecord.getSpannerValue(sessionState, Type.string(), "")); + assertEquals( + Value.string(null), CsvCopyRecord.getSpannerValue(sessionState, Type.string(), null)); } @Test public void testGetSpannerValueDate() { + SessionState sessionState = mock(SessionState.class); + assertEquals( Value.date(Date.parseDate("2022-08-17")), - CsvCopyRecord.getSpannerValue(Type.date(), "2022-08-17")); - assertEquals(Value.date(null), CsvCopyRecord.getSpannerValue(Type.date(), null)); + CsvCopyRecord.getSpannerValue(sessionState, Type.date(), "2022-08-17")); + assertEquals(Value.date(null), CsvCopyRecord.getSpannerValue(sessionState, Type.date(), null)); } @Test public void testGetSpannerValueTimestamp() { + SessionState sessionState = mock(SessionState.class); + assertEquals( Value.timestamp(Timestamp.parseTimestamp("2093-08-02T14:53:40.481913Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2093-08-02T14:53:40.481913+00")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2093-08-02T14:53:40.481913+00")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17T10:11:12.123456789Z")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17T10:11:12.123456789Z")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17 10:11:12.123456789Z")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17 10:11:12.123456789Z")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17 10:11:12.123456789+00")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17 10:11:12.123456789+00")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17 10:11:12.123456789+00:00")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17 10:11:12.123456789+00:00")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17T10:11:12.123456789+00")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17T10:11:12.123456789+00")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17T10:11:12.123456789+00:00")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17T10:11:12.123456789+00:00")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17 12:11:12.123456789+02")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17 12:11:12.123456789+02")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17T12:11:12.123456789+02")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17T12:11:12.123456789+02")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17 08:11:12.123456789-02")); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17 08:11:12.123456789-02")); assertEquals( Value.timestamp(Timestamp.parseTimestamp("2022-08-17T10:11:12.123456789Z")), - CsvCopyRecord.getSpannerValue(Type.timestamp(), "2022-08-17T08:11:12.123456789-02")); - assertEquals(Value.date(null), CsvCopyRecord.getSpannerValue(Type.date(), null)); + CsvCopyRecord.getSpannerValue( + sessionState, Type.timestamp(), "2022-08-17T08:11:12.123456789-02")); + assertEquals(Value.date(null), CsvCopyRecord.getSpannerValue(sessionState, Type.date(), null)); } @Test public void testGetSpannerValue_InvalidBytesValue() { + SessionState sessionState = mock(SessionState.class); + assertThrows( - SpannerException.class, () -> CsvCopyRecord.getSpannerValue(Type.bytes(), "value")); + SpannerException.class, + () -> CsvCopyRecord.getSpannerValue(sessionState, Type.bytes(), "value")); } @Test public void testGetSpannerValue_InvalidNumberValue() { + SessionState sessionState = mock(SessionState.class); + assertThrows( - SpannerException.class, () -> CsvCopyRecord.getSpannerValue(Type.int64(), "value")); + SpannerException.class, + () -> CsvCopyRecord.getSpannerValue(sessionState, Type.int64(), "value")); } @Test public void testGetSpannerValue_InvalidBoolValue() { - assertThrows(SpannerException.class, () -> CsvCopyRecord.getSpannerValue(Type.bool(), "value")); + SessionState sessionState = mock(SessionState.class); + + assertThrows( + SpannerException.class, + () -> CsvCopyRecord.getSpannerValue(sessionState, Type.bool(), "value")); } @Test public void testGetSpannerValue_InvalidDateValue() { - assertThrows(SpannerException.class, () -> CsvCopyRecord.getSpannerValue(Type.date(), "value")); + SessionState sessionState = mock(SessionState.class); + + assertThrows( + SpannerException.class, + () -> CsvCopyRecord.getSpannerValue(sessionState, Type.date(), "value")); } @Test public void testGetSpannerValue_InvalidTimestampValue() { + SessionState sessionState = mock(SessionState.class); + assertThrows( - SpannerException.class, () -> CsvCopyRecord.getSpannerValue(Type.timestamp(), "value")); + SpannerException.class, + () -> CsvCopyRecord.getSpannerValue(sessionState, Type.timestamp(), "value")); } @Test public void testGetSpannerValue_UnsupportedType() { + SessionState sessionState = mock(SessionState.class); + assertThrows( SpannerException.class, () -> CsvCopyRecord.getSpannerValue( - Type.struct(StructField.of("f1", Type.string())), "value")); + sessionState, Type.struct(StructField.of("f1", Type.string())), "value")); } }