Skip to content

Commit

Permalink
fix: correct type inference for arithmetic operations (#359)
Browse files Browse the repository at this point in the history
This update addresses an issue where the inferred type was incorrect when dealing with arithmetic operations in SQL queries. The changes ensure that the type is accurately resolved based on the operation and operand types.
  • Loading branch information
Newbie012 authored Jan 6, 2025
1 parent 8502e67 commit 226d29e
Show file tree
Hide file tree
Showing 10 changed files with 1,083 additions and 222 deletions.
5 changes: 5 additions & 0 deletions .changeset/chatty-hairs-hide.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@ts-safeql/generate": patch
---

fixed an issue where the inferred typed was incorrect when dealing with arithmetic operations
2 changes: 1 addition & 1 deletion packages/eslint-plugin/src/rules/check-sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,7 @@ RuleTester.describe("check-sql", () => {
await sql<Caregiver[]>\`
SELECT
CASE WHEN caregiver.id IS NOT NULL
THEN jsonb_build_object('is_test', caregiver.middle_name NOT LIKE '%test%')
THEN jsonb_build_object('is_test', caregiver.first_name LIKE '%test%')
ELSE NULL
END AS meta
FROM
Expand Down
2 changes: 1 addition & 1 deletion packages/generate/src/ast-decribe.utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ export function isSingleCell<T>(arr: T[]): arr is [T] {
return arr.length === 1;
}

function isTuple<T>(arr: T[]): arr is [T, T] {
export function isTuple<T>(arr: T[]): arr is [T, T] {
return arr.length === 2;
}
133 changes: 112 additions & 21 deletions packages/generate/src/ast-describe.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { fmap, normalizeIndent } from "@ts-safeql/shared";
import { defaultTypeExprMapping, fmap, normalizeIndent } from "@ts-safeql/shared";
import * as LibPgQueryAST from "@ts-safeql/sql-ast";
import {
isColumnStarRef,
isColumnTableColumnRef,
isColumnTableStarRef,
isColumnUnknownRef,
isSingleCell,
isTuple,
} from "./ast-decribe.utils";
import { ResolvedColumn, SourcesResolver, getSources } from "./ast-get-sources";
import { PgColRow, PgEnumsMaps, PgTypesMap } from "./generate";
Expand All @@ -20,7 +21,7 @@ type ASTDescriptionOptions = {
pgColsBySchemaAndTableName: Map<string, Map<string, PgColRow[]>>;
pgTypes: PgTypesMap;
pgEnums: PgEnumsMaps;
pgFns: Map<string, string>;
pgFns: Map<string, { ts: string; pg: string }>;
};

type ASTDescriptionContext = ASTDescriptionOptions & {
Expand All @@ -38,7 +39,7 @@ export type ASTDescribedColumnType =
| { kind: "union"; value: ASTDescribedColumnType[] }
| { kind: "array"; value: ASTDescribedColumnType }
| { kind: "object"; value: [string, ASTDescribedColumnType][] }
| { kind: "type"; value: string }
| { kind: "type"; value: string; type: string }
| { kind: "literal"; value: string; base: ASTDescribedColumnType };

export function getASTDescription(params: ASTDescriptionOptions): Map<number, ASTDescribedColumn> {
Expand Down Expand Up @@ -82,20 +83,32 @@ export function getASTDescription(params: ASTDescriptionOptions): Map<number, AS
p: { oid: number; baseOid: number | null } | { name: string },
): ASTDescribedColumnType => {
if ("name" in p) {
return { kind: "type", value: params.typesMap.get(p.name)?.value ?? "unknown" };
return {
kind: "type",
value: params.typesMap.get(p.name)?.value ?? "unknown",
type: p.name,
};
}

const typeByOid = getTypeByOid(p.oid);

if (typeByOid.override) {
const baseType: ASTDescribedColumnType = { kind: "type", value: typeByOid.value };
const baseType: ASTDescribedColumnType = {
kind: "type",
value: typeByOid.value,
type: params.pgTypes.get(p.oid)?.name ?? "unknown",
};
return typeByOid.isArray ? { kind: "array", value: baseType } : baseType;
}

const typeByBaseOid = fmap(p.baseOid, getTypeByOid);

if (typeByBaseOid?.override === true) {
const baseType: ASTDescribedColumnType = { kind: "type", value: typeByBaseOid.value };
const baseType: ASTDescribedColumnType = {
kind: "type",
value: typeByBaseOid.value,
type: params.pgTypes.get(p.baseOid!)?.name ?? "unknown",
};
return typeByBaseOid.isArray ? { kind: "array", value: baseType } : baseType;
}

Expand All @@ -104,13 +117,21 @@ export function getASTDescription(params: ASTDescriptionOptions): Map<number, AS
if (enumValue !== undefined) {
return {
kind: "union",
value: enumValue.values.map((value) => ({ kind: "type", value: `'${value}'` })),
value: enumValue.values.map((value) => ({
kind: "type",
value: `'${value}'`,
type: enumValue.name,
})),
};
}

const { isArray, value } = typeByBaseOid ?? typeByOid;

const type: ASTDescribedColumnType = { kind: "type", value: value };
const type: ASTDescribedColumnType = {
kind: "type",
value: value,
type: params.pgTypes.get(p.oid)?.name ?? "unknown",
};

return isArray ? { kind: "array", value: type } : type;
},
Expand Down Expand Up @@ -215,15 +236,81 @@ function getDescribedNode(params: {

function getDescribedAExpr({
alias,
node,
context,
}: GetDescribedParamsOf<LibPgQueryAST.AExpr>): ASTDescribedColumn[] {
const name = alias ?? "?column?";

if (node.lexpr === undefined && node.rexpr !== undefined) {
const described = getDescribedNode({ alias, node: node.rexpr, context }).at(0);
const type = fmap(described, (x) => getBaseType(x.type));

if (type === null) return [];

return [{ name, type }];
}

if (node.lexpr === undefined || node.rexpr === undefined) {
return [];
}

const getResolvedNullableValueOrNull = (node: LibPgQueryAST.Node) => {
const column = getDescribedNode({ alias: undefined, node, context }).at(0);

if (column === undefined) return null;

if (column.type.kind === "array") {
return { value: "array", nullable: false };
}

if (column.type.kind === "type") {
return { value: column.type.type, nullable: false };
}

if (column.type.kind === "literal" && column.type.base.kind === "type") {
return { value: column.type.base.type, nullable: false };
}

if (column.type.kind === "union" && isTuple(column.type.value)) {
let nullable = false;
let value: string | undefined = undefined;

for (const type of column.type.value) {
if (type.kind !== "type") return null;
if (type.value === "null") nullable = true;
if (type.value !== "null") value = type.type;
}

if (value === undefined) return null;

return { value, nullable };
}

return null;
};

const lnode = getResolvedNullableValueOrNull(node.lexpr);
const rnode = getResolvedNullableValueOrNull(node.rexpr);

if (lnode === null || rnode === null) {
return [];
}

const operator = concatStringNodes(node.name);
const resolved: string | undefined =
defaultTypeExprMapping[`${lnode.value} ${operator} ${rnode.value}`];

if (resolved === undefined) {
return [];
}

return [
{
name: alias ?? "?column?",
name: name,
type: resolveType({
context: context,
nullable: false,
type: context.toTypeScriptType({ name: "boolean" }),
nullable: !context.nonNullableColumns.has(name) && (lnode.nullable || rnode.nullable),
type: context.toTypeScriptType({ name: resolved }),
}),
},
];
Expand All @@ -239,7 +326,7 @@ function getDescribedNullTest({
type: resolveType({
context: context,
nullable: false,
type: context.toTypeScriptType({ name: "boolean" }),
type: context.toTypeScriptType({ name: "bool" }),
}),
},
];
Expand Down Expand Up @@ -298,7 +385,7 @@ function getDescribedBoolExpr({
type: resolveType({
context: context,
nullable: false,
type: context.toTypeScriptType({ name: "boolean" }),
type: context.toTypeScriptType({ name: "bool" }),
}),
},
];
Expand All @@ -317,7 +404,7 @@ function getDescribedSubLink({
nullable: false,
type: (() => {
if (node.subLinkType === LibPgQueryAST.SubLinkType.EXISTS_SUBLINK) {
return context.toTypeScriptType({ name: "boolean" });
return context.toTypeScriptType({ name: "bool" });
}

return context.toTypeScriptType({ name: "unknown" });
Expand Down Expand Up @@ -412,7 +499,7 @@ function mergeDescribedColumnTypes(types: ASTDescribedColumnType[]): ASTDescribe

if (!seenSymbols.has("boolean") && seenSymbols.has("true") && seenSymbols.has("false")) {
seenSymbols.add("boolean");
result.push({ kind: "type", value: "boolean" });
result.push({ kind: "type", value: "boolean", type: "bool" });
}

if (seenSymbols.has("boolean") && (seenSymbols.has("true") || seenSymbols.has("false"))) {
Expand Down Expand Up @@ -537,15 +624,15 @@ function getDescribedFuncCallByPgFn({

const pgFnValue =
args.length === 0
? context.pgFns.get(functionName)
? (context.pgFns.get(functionName) ?? context.pgFns.get(`${functionName}(string)`))
: (context.pgFns.get(`${functionName}(${args.join(", ")})`) ??
context.pgFns.get(`${functionName}(any)`) ??
context.pgFns.get(`${functionName}(unknown)`));

const type = resolveType({
context: context,
nullable: !context.nonNullableColumns.has(name),
type: { kind: "type", value: pgFnValue ?? "unknown" },
type: { kind: "type", value: pgFnValue?.ts ?? "unknown", type: pgFnValue?.pg ?? "unknown" },
});

return [{ name, type }];
Expand Down Expand Up @@ -758,7 +845,11 @@ function getDescribedColumnByResolvedColumns(params: {
?.get(column.colName);

if (overridenType !== undefined) {
return { kind: "type", value: overridenType };
return {
kind: "type",
value: overridenType,
type: params.context.pgTypes.get(column.colTypeOid)?.name ?? "unknown",
};
}

return params.context.toTypeScriptType({
Expand Down Expand Up @@ -789,7 +880,7 @@ function getDescribedAConst({
return {
kind: "literal",
value: node.boolval.boolval ? "true" : "false",
base: context.toTypeScriptType({ name: "boolean" }),
base: context.toTypeScriptType({ name: "bool" }),
};
case node.bsval !== undefined:
return context.toTypeScriptType({ name: "bytea" });
Expand Down Expand Up @@ -838,7 +929,7 @@ function asNonNullableType(type: ASTDescribedColumnType): ASTDescribedColumnType
);

if (filtered.length === 0) {
return { kind: "type", value: "unknown" };
return { kind: "type", value: "unknown", type: "unknown" };
}

if (filtered.length === 1) {
Expand All @@ -848,7 +939,7 @@ function asNonNullableType(type: ASTDescribedColumnType): ASTDescribedColumnType
return { kind: "union", value: filtered };
}
case "type":
return type.value === "null" ? { kind: "type", value: "unknown" } : type;
return type.value === "null" ? { kind: "type", value: "unknown", type: "unknown" } : type;
}
}

Expand Down
Loading

0 comments on commit 226d29e

Please sign in to comment.