diff --git a/packages/adapters/database/src/client.ts b/packages/adapters/database/src/client.ts index d1ffd0e91e..c983cec81c 100644 --- a/packages/adapters/database/src/client.ts +++ b/packages/adapters/database/src/client.ts @@ -901,6 +901,16 @@ export const getAggregateRoot = async ( return aggregateRootId.split("-")[0] ?? undefined; }; +export const getBaseAggregateRootCount = async ( + received_root: string, + _pool?: Pool | db.TxnClientForRepeatableRead, +): Promise => { + const poolToUse = _pool ?? pool; + // Get the leaf count at the aggregated root + const root = await db.selectOne("aggregated_roots", { received_root }).run(poolToUse); + return root ? convertFromDbAggregatedRoot(root).index : undefined; +}; + export const getAggregateRootCount = async ( aggreateRoot: string, _pool?: Pool | db.TxnClientForRepeatableRead, diff --git a/packages/adapters/database/src/index.ts b/packages/adapters/database/src/index.ts index 2cdc78c833..4363b0e496 100644 --- a/packages/adapters/database/src/index.ts +++ b/packages/adapters/database/src/index.ts @@ -57,6 +57,7 @@ import { getAggregateRoot, getAggregateRootByRootAndDomain, getAggregateRootCount, + getBaseAggregateRootCount, getAggregateRoots, getBaseAggregateRoot, getMessageRootIndex, @@ -203,6 +204,10 @@ export type Database = { aggregateRoot: string, _pool?: Pool | TxnClientForRepeatableRead, ) => Promise; + getBaseAggregateRootCount: ( + aggregateRoot: string, + _pool?: Pool | TxnClientForRepeatableRead, + ) => Promise; getAggregateRoots: (count: number, _pool?: Pool | TxnClientForRepeatableRead) => Promise; getBaseAggregateRoot: (_pool?: Pool | TxnClientForRepeatableRead) => Promise; getCurrentProposedSnapshot: (_pool?: Pool | TxnClientForRepeatableRead) => Promise; @@ -374,6 +379,7 @@ export const getDatabase = async (databaseUrl: string, logger: Logger): Promise< getAggregateRoot, getAggregateRootByRootAndDomain, getAggregateRootCount, + getBaseAggregateRootCount, getAggregateRoots, getBaseAggregateRoot, getMessageRootIndex, @@ -461,6 +467,7 @@ export const getDatabaseAndPool = async ( getAggregateRoot, getAggregateRootByRootAndDomain, getAggregateRootCount, + getBaseAggregateRootCount, getAggregateRoots, getBaseAggregateRoot, getMessageRootIndex, diff --git a/packages/adapters/database/test/client.spec.ts b/packages/adapters/database/test/client.spec.ts index 77fae21a97..3045d99fbd 100644 --- a/packages/adapters/database/test/client.spec.ts +++ b/packages/adapters/database/test/client.spec.ts @@ -56,6 +56,7 @@ import { getLatestMessageRoot, getLatestAggregateRoots, getAggregateRootCount, + getBaseAggregateRootCount, getUnProcessedMessages, getUnProcessedMessagesByIndex, getUnProcessedMessagesByDomains, @@ -1096,6 +1097,7 @@ describe("Database client", () => { expect(await getMessageRootCount("", "", pool)).to.eq(undefined); expect(await getMessageRootIndex("", "", pool)).to.eq(undefined); expect(await getAggregateRootCount("", pool)).to.eq(undefined); + expect(await getBaseAggregateRootCount("", pool)).to.eq(undefined); expect(await getAggregateRoot("", pool)).to.eq(undefined); expect(await getLatestMessageRoot("", "", pool)).to.eq(undefined); expect(await getLatestAggregateRoots("", 1, "DESC", pool)).to.be.deep.eq([]); @@ -1130,6 +1132,7 @@ describe("Database client", () => { ).to.eventually.not.be.rejected; await expect(getAggregateRoot(undefined as any, undefined as any)).to.eventually.not.be.rejected; await expect(getAggregateRootCount(undefined as any, undefined as any)).to.eventually.not.be.rejected; + await expect(getBaseAggregateRootCount(undefined as any, undefined as any)).to.eventually.not.be.rejected; await expect(getMessageRootIndex(undefined as any, undefined as any, undefined as any)).to.eventually.not.be .rejected; await expect(getMessageRootAggregatedFromIndex(undefined as any, undefined as any, undefined as any)).to.eventually diff --git a/packages/adapters/database/test/mock.ts b/packages/adapters/database/test/mock.ts index 310ec8f072..c886024ae3 100644 --- a/packages/adapters/database/test/mock.ts +++ b/packages/adapters/database/test/mock.ts @@ -31,6 +31,7 @@ export const mockDatabase = (): Database => { getUnProcessedMessagesByDomains: stub().resolves([]), getAggregateRoot: stub().resolves(), getAggregateRootCount: stub().resolves(), + getBaseAggregateRootCount: stub().resolves(), getMessageRootIndex: stub().resolves(), getMessageRootAggregatedFromIndex: stub().resolves(), getMessageRootCount: stub().resolves(), diff --git a/packages/agents/lighthouse/src/tasks/propose/operations/propose.ts b/packages/agents/lighthouse/src/tasks/propose/operations/propose.ts index 9ff0750ef1..1018ec561d 100644 --- a/packages/agents/lighthouse/src/tasks/propose/operations/propose.ts +++ b/packages/agents/lighthouse/src/tasks/propose/operations/propose.ts @@ -107,7 +107,7 @@ export const proposeSnapshot = async (snapshotId: string, snapshotRoots: string[ throw new NoBaseAggregateRoot(); } - const baseAggregateRootCount = await database.getAggregateRootCount(baseAggregateRoot); + const baseAggregateRootCount = await database.getBaseAggregateRootCount(baseAggregateRoot); if (!baseAggregateRootCount) { throw new NoBaseAggregateRootCount(baseAggregateRoot); } diff --git a/packages/agents/lighthouse/test/tasks/propose/operations/propose.spec.ts b/packages/agents/lighthouse/test/tasks/propose/operations/propose.spec.ts index b1fbd5b459..c9b5d8eaad 100644 --- a/packages/agents/lighthouse/test/tasks/propose/operations/propose.spec.ts +++ b/packages/agents/lighthouse/test/tasks/propose/operations/propose.spec.ts @@ -37,7 +37,7 @@ describe("Operations: Propose", () => { it("happy case should call propose snapshot succesfully", async () => { (proposeCtxMock.adapters.database.getBaseAggregateRoot as SinonStub).resolves("0x"); - (proposeCtxMock.adapters.database.getAggregateRootCount as SinonStub).resolves(1); + (proposeCtxMock.adapters.database.getBaseAggregateRootCount as SinonStub).resolves(1); (proposeCtxMock.adapters.database.getAggregateRoots as SinonStub).resolves(["0x"]); (proposeCtxMock.adapters.database.getPendingSnapshots as SinonStub).resolves([mock.entity.snapshotRoot()]); encodeFunctionData.returns("0x");