From 1bc1ce9d7b9932126cede071dd2ccc8d1df8e13e Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Wed, 4 Sep 2024 18:03:11 -0700 Subject: [PATCH] add tests for handlers cleanly exiting (#267) ## Why ## What changed ## Versioning - [ ] Breaking protocol change - [ ] Breaking ts/js API change --- __tests__/cancellation.test.ts | 217 ++++++++++++++++++++++++++++++++- package-lock.json | 4 +- package.json | 2 +- router/server.ts | 15 +-- 4 files changed, 225 insertions(+), 13 deletions(-) diff --git a/__tests__/cancellation.test.ts b/__tests__/cancellation.test.ts index 5dd79ccb..45904d95 100644 --- a/__tests__/cancellation.test.ts +++ b/__tests__/cancellation.test.ts @@ -39,7 +39,222 @@ function makeMockHandler( >(impl); } -describe.each(testMatrix(['ws', 'naive']))( +describe.each(testMatrix())( + 'clean handler cancellation ($transport.name transport, $codec.name codec)', + + async ({ transport, codec }) => { + const opts = { codec: codec.codec }; + + const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); + let getClientTransport: TestSetupHelpers['getClientTransport']; + let getServerTransport: TestSetupHelpers['getServerTransport']; + beforeEach(async () => { + const setup = await transport.setup({ client: opts, server: opts }); + getClientTransport = setup.getClientTransport; + getServerTransport = setup.getServerTransport; + + return async () => { + await postTestCleanup(); + await setup.cleanup(); + }; + }); + + describe('e2e', () => { + test('rpc', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + + const signalReceiver = vi.fn<(sig: AbortSignal) => void>(); + const services = { + service: ServiceSchema.define({ + rpc: Procedure.rpc({ + requestInit: Type.Object({}), + responseData: Type.Object({}), + handler: async ({ ctx }) => { + signalReceiver(ctx.signal); + + return Ok({}); + }, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + await client.service.rpc.rpc({}); + + await waitFor(() => { + expect(signalReceiver).toHaveBeenCalledTimes(1); + }); + + const [sig] = signalReceiver.mock.calls[0]; + expect(sig.aborted).toEqual(true); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('stream', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const signalReceiver = vi.fn<(sig: AbortSignal) => void>(); + const services = { + service: ServiceSchema.define({ + stream: Procedure.stream({ + requestInit: Type.Object({}), + requestData: Type.Object({}), + responseData: Type.Object({}), + handler: async ({ ctx, resWritable }) => { + signalReceiver(ctx.signal); + + resWritable.write(Ok({})); + resWritable.close(); + + return; + }, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { reqWritable, resReadable } = client.service.stream.stream({}); + + await waitFor(() => { + expect(signalReceiver).toHaveBeenCalledTimes(1); + }); + + const [sig] = signalReceiver.mock.calls[0]; + expect(sig.aborted).toEqual(false); + + reqWritable.close(); + await waitFor(() => expect(sig.aborted).toEqual(true)); + + // collect should resolve as the stream has been properly ended + await expect(resReadable.collect()).resolves.toEqual([Ok({})]); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('upload', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const signalReceiver = vi.fn<(sig: AbortSignal) => void>(); + const services = { + service: ServiceSchema.define({ + upload: Procedure.upload({ + requestInit: Type.Object({}), + requestData: Type.Object({}), + responseData: Type.Object({}), + handler: async ({ ctx }) => { + signalReceiver(ctx.signal); + + return Ok({}); + }, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { reqWritable, finalize } = client.service.upload.upload({}); + + await waitFor(() => { + expect(signalReceiver).toHaveBeenCalledTimes(1); + }); + + const [sig] = signalReceiver.mock.calls[0]; + expect(sig.aborted).toEqual(false); + + reqWritable.close(); + await waitFor(() => expect(sig.aborted).toEqual(true)); + + expect(await finalize()).toEqual(Ok({})); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('subscribe', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const signalReceiver = vi.fn<(sig: AbortSignal) => void>(); + const services = { + service: ServiceSchema.define({ + subscribe: Procedure.subscription({ + requestInit: Type.Object({}), + responseData: Type.Object({}), + handler: async ({ ctx, resWritable }) => { + resWritable.close(); + signalReceiver(ctx.signal); + + return; + }, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { resReadable } = client.service.subscribe.subscribe({}); + + await waitFor(() => { + expect(signalReceiver).toHaveBeenCalledTimes(1); + }); + + const [sig] = signalReceiver.mock.calls[0]; + expect(sig.aborted).toEqual(true); + await expect(resReadable.collect()).resolves.toEqual([]); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + }); + }, +); + +describe.each(testMatrix())( 'client initiated cancellation ($transport.name transport, $codec.name codec)', async ({ transport, codec }) => { const opts = { codec: codec.codec }; diff --git a/package-lock.json b/package-lock.json index 89861420..e67ef326 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.200.4", + "version": "0.200.5", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.200.4", + "version": "0.200.5", "license": "MIT", "dependencies": { "@msgpack/msgpack": "^3.0.0-beta2", diff --git a/package.json b/package.json index 9e874b3c..f2ef4a88 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@replit/river", "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.200.4", + "version": "0.200.5", "type": "module", "exports": { ".": { diff --git a/router/server.ts b/router/server.ts index a4df77f5..c181d856 100644 --- a/router/server.ts +++ b/router/server.ts @@ -195,11 +195,10 @@ class RiverServer } // if its not a cancelled stream, validate and create a new stream - const newStream = this.createNewProcStream({ + this.createNewProcStream({ ...newStreamProps, ...message, }); - this.streams.set(streamId, newStream); }; const handleSessionStatus = (evt: EventMap['sessionStatus']) => { @@ -235,7 +234,7 @@ class RiverServer this.transport.addEventListener('transportStatus', handleTransportStatus); } - private createNewProcStream(props: StreamInitProps): ProcStream { + private createNewProcStream(props: StreamInitProps) { const { streamId, initialSession, @@ -369,6 +368,7 @@ class RiverServer }); }; + const finishedController = new AbortController(); const procStream: ProcStream = { from: from, streamId, @@ -422,7 +422,6 @@ class RiverServer cancelStream(streamId, result); }; - const finishedController = new AbortController(); const cleanup = () => { finishedController.abort(); this.streams.delete(streamId); @@ -531,10 +530,6 @@ class RiverServer // only consists of an init message and we shouldn't expect follow up data if (procClosesWithInit) { closeReadable(); - } else if (procedure.type === 'rpc' || procedure.type === 'subscription') { - // Though things can work just fine if they eventually follow up with a stream - // control message with a close bit set, it's an unusual client implementation! - this.log?.warn('sent an init without a stream close', loggingMetadata); } const handlerContext: ProcedureHandlerContext = { @@ -658,7 +653,9 @@ class RiverServer break; } - return procStream; + if (!finishedController.signal.aborted) { + this.streams.set(streamId, procStream); + } } private getContext(service: AnyService, serviceName: string) {