diff --git a/README.md b/README.md index 62f684b..004d439 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ npm i @bdelab/jscat ## Usage +For existing jsCAT users: to make your applications compatible with the updated jsCAT version, you will need to pass the stimuli in the following way: + ```js // import jsCAT import { Cat, normal } from '@bdelab/jscat'; @@ -26,7 +28,15 @@ const currentPrior = normal(); // create a Cat object const cat = new CAT({method: 'MLE', itemSelect: 'MFI', nStartItems: 0, theta: 0, minTheta: -6, maxTheta: 6, prior: currentPrior}) -// update the abilitiy estimate by adding test items +// option 1 to input stimuli: +const zeta = [{discrimination: 1, difficulty: 0, guessing: 0, slipping: 1}, {discrimination: 1, difficulty: 0.5, guessing: 0, slipping: 1}] + +// option 2 to input stimuli (use instead of option 1): +const zeta = [{a: 1, b: 0, c: 0, d: 1}, {a: 1, b: 0.5, c: 0, d: 1}] + +const answer = [1, 0] + +// update the ability estimate by adding test items cat.updateAbilityEstimate(zeta, answer); const currentTheta = cat.theta; @@ -37,29 +47,231 @@ const numItems = cat.nItems; // find the next available item from an input array of stimuli based on a selection method -const stimuli = [{difficulty: -3, item: 'item1'}, {difficulty: -2, item: 'item2'}]; +// Note for existing jsCAT users: to make your applications compatible with the updated jsCAT version, you will need to pass the stimuli in the following way: + +const stimuli = [{ discrimination: 1, difficulty: -2, guessing: 0, slipping: 1, item: "item1" },{ discrimination: 1, difficulty: 3, guessing: 0, slipping: 1, item: "item2" }]; const nextItem = cat.findNextItem(stimuli, 'MFI'); ``` -## Validations + +## Validation + ### Validation of theta estimate and theta standard error + Reference software: mirt (Chalmers, 2012) ![img.png](validation/plots/jsCAT_validation_1.png) ### Validation of MFI algorithm + Reference software: catR (Magis et al., 2017) 
![img_1.png](validation/plots/jsCAT_validation_2.png) +# Clowder Usage Guide + +The `Clowder` class is a powerful tool for managing multiple `Cat` instances and handling stimuli corpora in adaptive testing scenarios. This guide provides an overview of integrating `Clowder` into your application, with examples and explanations for key features. + +--- + +## Key Changes from Single `Cat` to `Clowder` + +### Why Use Clowder? + +- **Multi-CAT Support**: Manage multiple `Cat` instances simultaneously. +- **Centralized Corpus Management**: Handle validated and unvalidated items across Cats. +- **Advanced Trial Management**: Dynamically update Cats and retrieve stimuli based on configurable rules. +- **Early Stopping Mechanisms**: Optimize testing by integrating conditions to stop trials automatically. + +--- + +## Transitioning to Clowder + +### 1. Replacing Single `Cat` Usage + +#### Single `Cat` Example: +```typescript +const cat = new Cat({ method: 'MLE', theta: 0.5 }); +const nextItem = cat.findNextItem(stimuli); +``` + +#### Clowder Equivalent: +```typescript +const clowder = new Clowder({ + cats: { cat1: { method: 'MLE', theta: 0.5 } }, + corpus: stimuli, +}); +const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', +}); +``` + +--- + +### 2. Using a Corpus with Multi-Zeta Stimuli + +The `Clowder` corpus supports multi-zeta stimuli, allowing each stimulus to define parameters for multiple Cats. Use the following tools to prepare the corpus: + +#### Fill Default Zeta Parameters: + +```typescript +import { fillZetaDefaults } from './corpus'; + +const filledStimuli = stimuli.map((stim) => fillZetaDefaults(stim)); + +``` + +**What is `fillZetaDefaults`?** +The function `fillZetaDefaults` ensures that each stimulus in the corpus has Zeta parameters defined. If any parameters are missing, it fills them with the default Zeta values. 
+ +The default values are: + +```typescript +export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { + const defaultZeta: Zeta = { + a: 1, + b: 0, + c: 0, + d: 1, + }; + + return convertZeta(defaultZeta, desiredFormat); +}; + +``` +- If desiredFormat is not specified, it defaults to 'symbolic'. +- This ensures consistency across different stimuli and prevents errors from missing Zeta parameters. +- You can pass 'semantic' as an argument to convert the default Zeta values into a different representation. + +#### Validate the Corpus: +```typescript +import { checkNoDuplicateCatNames } from './corpus'; +checkNoDuplicateCatNames(corpus); +``` + +#### Filter Stimuli for a Specific Cat: +```typescript +import { filterItemsByCatParameterAvailability } from './corpus'; +const { available, missing } = filterItemsByCatParameterAvailability(corpus, 'cat1'); +``` + +--- + +### 3. Adding Early Stopping + +Integrate early stopping mechanisms to optimize the testing process. + +#### Example: Stop After N Items +```typescript +import { StopAfterNItems } from './stopping'; + +const earlyStopping = new StopAfterNItems({ + requiredItems: { cat1: 2 }, +}); + +const clowder = new Clowder({ + cats: { cat1: { method: 'MLE', theta: 0.5 } }, + corpus: stimuli, + earlyStopping: earlyStopping, +}); +``` + +## Early Stopping Criteria Combinations + +To clarify the available combinations for early stopping, here’s a breakdown of the options you can use: + +### 1. Logical Operations + +You can combine multiple stopping criteria using one of the following logical operations: + +- **`and`**: All conditions need to be met to trigger early stopping. +- **`or`**: Any one condition being met will trigger early stopping. +- **`only`**: Only a specific condition is considered (you need to specify the cat to evaluate). + +### 2. 
Stopping Criteria Classes + +There are different types of stopping criteria you can configure: + +- **`StopAfterNItems`**: Stops the process after a specified number of items. +- **`StopOnSEMeasurementPlateau`**: Stops if the standard error (SE) of measurement remains stable (within a defined tolerance) for a specified number of items. +- **`StopIfSEMeasurementBelowThreshold`**: Stops if the SE measurement drops below a set threshold. + +### How Combinations Work + +You can mix and match these criteria with different logical operations, giving you a range of configurations for early stopping. For example: + +- Using **`and`** with both `StopAfterNItems` and `StopIfSEMeasurementBelowThreshold` means stopping will only occur if both conditions are satisfied. +- Using **`or`** with `StopOnSEMeasurementPlateau` and `StopAfterNItems` allows early stopping if either condition is met. + +--- + +## Clowder Example + +Here’s a complete example demonstrating how to configure and use `Clowder`: + +```typescript +import { Clowder } from './clowder'; +import { createMultiZetaStimulus, createZetaCatMap } from './utils'; +import { StopAfterNItems } from './stopping'; + +// Define the Cats +const catConfigs = { + cat1: { method: 'MLE', theta: 0.5 }, // Cat1 uses Maximum Likelihood Estimation + cat2: { method: 'EAP', theta: -1.0 }, // Cat2 uses Expected A Posteriori +}; + +// Define the corpus +const corpus = [ + createMultiZetaStimulus('item1', [ + createZetaCatMap(['cat1'], { a: 1, b: 0.5, c: 0.2, d: 0.8 }), + createZetaCatMap(['cat2'], { a: 2, b: 0.7, c: 0.3, d: 0.9 }), + ]), + createMultiZetaStimulus('item2', [createZetaCatMap(['cat1'], { a: 1.5, b: 0.4, c: 0.1, d: 0.85 })]), + createMultiZetaStimulus('item3', [createZetaCatMap(['cat2'], { a: 2.5, b: 0.6, c: 0.25, d: 0.95 })]), + createMultiZetaStimulus('item4', []), // Unvalidated item +]; + +// Optional: Add an early stopping strategy +const earlyStopping = new StopAfterNItems({ + requiredItems: { cat1: 2, cat2: 2 }, +}); + 
+// Initialize the Clowder +const clowder = new Clowder({ + cats: catConfigs, + corpus: corpus, + earlyStopping: earlyStopping, +}); + +// Running Trials +const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1', 'cat2'], // Update responses for both Cats + items: [clowder.corpus[0]], // Previously seen item + answers: [1], // Response for the previously seen item +}); + +console.log('Next item to present:', nextItem); + +// Check stopping condition +if (clowder.earlyStopping?.earlyStop) { + console.log('Early stopping triggered:', clowder.stoppingReason); +} +``` + +--- + +By integrating `Clowder`, your application can efficiently manage adaptive testing scenarios with robust trial and stimuli handling, multi-CAT configurations, and stopping conditions to ensure optimal performance. ## References -Chalmers, R. P. (2012). mirt: A multidimensional item response theory package for the R environment. Journal of Statistical Software. -Magis, D., & Barrada, J. R. (2017). Computerized adaptive testing with R: Recent updates of the package catR. Journal of Statistical Software, 76, 1-19. +- Chalmers, R. P. (2012). mirt: A multidimensional item response theory package for the R environment. Journal of Statistical Software. + +- Magis, D., & Barrada, J. R. (2017). Computerized adaptive testing with R: Recent updates of the package catR. Journal of Statistical Software, 76, 1-19. -Lucas Duailibe, irt-js, (2019), GitHub repository, https://github.com/geekie/irt-js +- Lucas Duailibe, irt-js, (2019), GitHub repository, https://github.com/geekie/irt-js ## License + jsCAT is distributed under the [ISC license](LICENSE). ## Contributors @@ -78,4 +290,4 @@ Ma, W. A., Richie-Halford, A., Burkhardt, A. 
K., Kanopka, K., Chou, C., Domingue pages={1--17}, year={2025}, publisher={Springer} -} +} \ No newline at end of file diff --git a/package-lock.json b/package-lock.json index e11a89c..144274e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16,7 +16,7 @@ "seedrandom": "^3.0.5" }, "devDependencies": { - "@types/jest": "^28.1.6", + "@types/jest": "^28.1.8", "@types/lodash": "^4.14.182", "@types/seedrandom": "^3.0.2", "@typescript-eslint/eslint-plugin": "^5.30.7", @@ -24,8 +24,9 @@ "eslint": "^8.20.0", "eslint-config-prettier": "^8.5.0", "jest": "^28.1.3", + "jest-extended": "^4.0.2", "prettier": "^2.7.1", - "ts-jest": "^28.0.7", + "ts-jest": "^28.0.8", "tsdoc": "^0.0.4", "typescript": "^4.7.4" } @@ -1863,12 +1864,13 @@ } }, "node_modules/@types/jest": { - "version": "28.1.6", - "resolved": "https://registry.npmjs.org/@types/jest/-/jest-28.1.6.tgz", - "integrity": "sha512-0RbGAFMfcBJKOmqRazM8L98uokwuwD5F8rHrv/ZMbrZBwVOWZUyPG6VFNscjYr/vjM3Vu4fRrCPbOs42AfemaQ==", + "version": "28.1.8", + "resolved": "https://registry.npmjs.org/@types/jest/-/jest-28.1.8.tgz", + "integrity": "sha512-8TJkV++s7B6XqnDrzR1m/TT0A0h948Pnl/097veySPN67VRAgQ4gZ7n2KfJo2rVq6njQjdxU3GCCyDvAeuHoiw==", "dev": true, + "license": "MIT", "dependencies": { - "jest-matcher-utils": "^28.0.0", + "expect": "^28.0.0", "pretty-format": "^28.0.0" } }, @@ -6943,6 +6945,188 @@ "node": "^12.13.0 || ^14.15.0 || ^16.10.0 || >=17.0.0" } }, + "node_modules/jest-extended": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/jest-extended/-/jest-extended-4.0.2.tgz", + "integrity": "sha512-FH7aaPgtGYHc9mRjriS0ZEHYM5/W69tLrFTIdzm+yJgeoCmmrSB/luSfMSqWP9O29QWHPEmJ4qmU6EwsZideog==", + "dev": true, + "license": "MIT", + "dependencies": { + "jest-diff": "^29.0.0", + "jest-get-type": "^29.0.0" + }, + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + }, + "peerDependencies": { + "jest": ">=27.2.5" + }, + "peerDependenciesMeta": { + "jest": { + "optional": true + } + } + }, + 
"node_modules/jest-extended/node_modules/@jest/schemas": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/@jest/schemas/-/schemas-29.6.3.tgz", + "integrity": "sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@sinclair/typebox": "^0.27.8" + }, + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/@sinclair/typebox": { + "version": "0.27.8", + "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.27.8.tgz", + "integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==", + "dev": true, + "license": "MIT" + }, + "node_modules/jest-extended/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/jest-extended/node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/jest-extended/node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": 
"sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/jest-extended/node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true, + "license": "MIT" + }, + "node_modules/jest-extended/node_modules/diff-sequences": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/diff-sequences/-/diff-sequences-29.6.3.tgz", + "integrity": "sha512-EjePK1srD3P08o2j4f0ExnylqRs5B9tJjcp9t1krH2qRi8CCdsYfwe9JgSLurFBWwq4uOlipzfk5fHNvwFKr8Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/jest-extended/node_modules/jest-diff": { + "version": "29.7.0", + "resolved": "https://registry.npmjs.org/jest-diff/-/jest-diff-29.7.0.tgz", + "integrity": "sha512-LMIgiIrhigmPrs03JHpxUh2yISK3vLFPkAodPeo0+BuF7wA2FoQbkEg1u8gBYBThncu7e1oEDUfIXVuTqLRUjw==", + "dev": true, + "license": "MIT", + "dependencies": { + "chalk": "^4.0.0", + "diff-sequences": "^29.6.3", + "jest-get-type": "^29.6.3", + "pretty-format": "^29.7.0" + }, + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/jest-get-type": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/jest-get-type/-/jest-get-type-29.6.3.tgz", + "integrity": 
"sha512-zrteXnqYxfQh7l5FHyL38jL39di8H8rHoecLH3JNxH3BwOrBsNeabdap5e0I23lD4HHI8W5VFBZqG4Eaq5LNcw==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/pretty-format": { + "version": "29.7.0", + "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.7.0.tgz", + "integrity": "sha512-Pdlw/oPxN+aXdmM9R00JVC9WVFoCLTKJvDVLgmJ+qAffBMxsV85l/Lu7sNx4zSzPyoL2euImuEwHhOXdEgNFZQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jest/schemas": "^29.6.3", + "ansi-styles": "^5.0.0", + "react-is": "^18.0.0" + }, + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/pretty-format/node_modules/ansi-styles": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", + "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/jest-extended/node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/jest-get-type": { "version": "28.0.2", "resolved": "https://registry.npmjs.org/jest-get-type/-/jest-get-type-28.0.2.tgz", @@ -15944,10 +16128,11 @@ } }, "node_modules/ts-jest": { - "version": "28.0.7", - "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-28.0.7.tgz", - "integrity": "sha512-wWXCSmTwBVmdvWrOpYhal79bDpioDy4rTT+0vyUnE3ZzM7LOAAGG9NXwzkEL/a516rQEgnMmS/WKP9jBPCVJyA==", + "version": "28.0.8", + "resolved": 
"https://registry.npmjs.org/ts-jest/-/ts-jest-28.0.8.tgz", + "integrity": "sha512-5FaG0lXmRPzApix8oFG8RKjAz4ehtm8yMKOTy5HX3fY6W8kmvOrmcY0hKDElW52FJov+clhUbrKAqofnj4mXTg==", "dev": true, + "license": "MIT", "dependencies": { "bs-logger": "0.x", "fast-json-stable-stringify": "2.x", @@ -18107,12 +18292,12 @@ } }, "@types/jest": { - "version": "28.1.6", - "resolved": "https://registry.npmjs.org/@types/jest/-/jest-28.1.6.tgz", - "integrity": "sha512-0RbGAFMfcBJKOmqRazM8L98uokwuwD5F8rHrv/ZMbrZBwVOWZUyPG6VFNscjYr/vjM3Vu4fRrCPbOs42AfemaQ==", + "version": "28.1.8", + "resolved": "https://registry.npmjs.org/@types/jest/-/jest-28.1.8.tgz", + "integrity": "sha512-8TJkV++s7B6XqnDrzR1m/TT0A0h948Pnl/097veySPN67VRAgQ4gZ7n2KfJo2rVq6njQjdxU3GCCyDvAeuHoiw==", "dev": true, "requires": { - "jest-matcher-utils": "^28.0.0", + "expect": "^28.0.0", "pretty-format": "^28.0.0" } }, @@ -22117,6 +22302,125 @@ "jest-util": "^28.1.3" } }, + "jest-extended": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/jest-extended/-/jest-extended-4.0.2.tgz", + "integrity": "sha512-FH7aaPgtGYHc9mRjriS0ZEHYM5/W69tLrFTIdzm+yJgeoCmmrSB/luSfMSqWP9O29QWHPEmJ4qmU6EwsZideog==", + "dev": true, + "requires": { + "jest-diff": "^29.0.0", + "jest-get-type": "^29.0.0" + }, + "dependencies": { + "@jest/schemas": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/@jest/schemas/-/schemas-29.6.3.tgz", + "integrity": "sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==", + "dev": true, + "requires": { + "@sinclair/typebox": "^0.27.8" + } + }, + "@sinclair/typebox": { + "version": "0.27.8", + "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.27.8.tgz", + "integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==", + "dev": true + }, + "ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": 
"sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "requires": { + "color-convert": "^2.0.1" + } + }, + "chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "requires": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + } + }, + "color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "requires": { + "color-name": "~1.1.4" + } + }, + "color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true + }, + "diff-sequences": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/diff-sequences/-/diff-sequences-29.6.3.tgz", + "integrity": "sha512-EjePK1srD3P08o2j4f0ExnylqRs5B9tJjcp9t1krH2qRi8CCdsYfwe9JgSLurFBWwq4uOlipzfk5fHNvwFKr8Q==", + "dev": true + }, + "has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true + }, + "jest-diff": { + "version": "29.7.0", + "resolved": "https://registry.npmjs.org/jest-diff/-/jest-diff-29.7.0.tgz", + "integrity": "sha512-LMIgiIrhigmPrs03JHpxUh2yISK3vLFPkAodPeo0+BuF7wA2FoQbkEg1u8gBYBThncu7e1oEDUfIXVuTqLRUjw==", + "dev": true, + "requires": { + "chalk": "^4.0.0", + "diff-sequences": "^29.6.3", + "jest-get-type": "^29.6.3", + "pretty-format": "^29.7.0" + } + }, + "jest-get-type": { + "version": "29.6.3", + "resolved": 
"https://registry.npmjs.org/jest-get-type/-/jest-get-type-29.6.3.tgz", + "integrity": "sha512-zrteXnqYxfQh7l5FHyL38jL39di8H8rHoecLH3JNxH3BwOrBsNeabdap5e0I23lD4HHI8W5VFBZqG4Eaq5LNcw==", + "dev": true + }, + "pretty-format": { + "version": "29.7.0", + "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.7.0.tgz", + "integrity": "sha512-Pdlw/oPxN+aXdmM9R00JVC9WVFoCLTKJvDVLgmJ+qAffBMxsV85l/Lu7sNx4zSzPyoL2euImuEwHhOXdEgNFZQ==", + "dev": true, + "requires": { + "@jest/schemas": "^29.6.3", + "ansi-styles": "^5.0.0", + "react-is": "^18.0.0" + }, + "dependencies": { + "ansi-styles": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", + "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", + "dev": true + } + } + }, + "supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "requires": { + "has-flag": "^4.0.0" + } + } + } + }, "jest-get-type": { "version": "28.0.2", "resolved": "https://registry.npmjs.org/jest-get-type/-/jest-get-type-28.0.2.tgz", @@ -29282,9 +29586,9 @@ "integrity": "sha512-WZGXGstmCWgeevgTL54hrCuw1dyMQIzWy7ZfqRJfSmJZBwklI15egmQytFP6bPidmw3M8d5yEowl1niq4vmqZw==" }, "ts-jest": { - "version": "28.0.7", - "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-28.0.7.tgz", - "integrity": "sha512-wWXCSmTwBVmdvWrOpYhal79bDpioDy4rTT+0vyUnE3ZzM7LOAAGG9NXwzkEL/a516rQEgnMmS/WKP9jBPCVJyA==", + "version": "28.0.8", + "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-28.0.8.tgz", + "integrity": "sha512-5FaG0lXmRPzApix8oFG8RKjAz4ehtm8yMKOTy5HX3fY6W8kmvOrmcY0hKDElW52FJov+clhUbrKAqofnj4mXTg==", "dev": true, "requires": { "bs-logger": "0.x", diff --git a/package.json b/package.json index 8671448..9c10826 100644 --- a/package.json +++ 
b/package.json @@ -33,7 +33,7 @@ }, "homepage": "https://github.com/yeatmanlab/jsCAT#readme", "devDependencies": { - "@types/jest": "^28.1.6", + "@types/jest": "^28.1.8", "@types/lodash": "^4.14.182", "@types/seedrandom": "^3.0.2", "@typescript-eslint/eslint-plugin": "^5.30.7", @@ -41,8 +41,9 @@ "eslint": "^8.20.0", "eslint-config-prettier": "^8.5.0", "jest": "^28.1.3", + "jest-extended": "^4.0.2", "prettier": "^2.7.1", - "ts-jest": "^28.0.7", + "ts-jest": "^28.0.8", "tsdoc": "^0.0.4", "typescript": "^4.7.4" }, diff --git a/src/__tests__/cat.test.ts b/src/__tests__/cat.test.ts new file mode 100644 index 0000000..fb9cdf4 --- /dev/null +++ b/src/__tests__/cat.test.ts @@ -0,0 +1,233 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ +import { Cat } from '..'; +import { Stimulus } from '../type'; +import seedrandom from 'seedrandom'; +import { convertZeta } from '../corpus'; + +for (const format of ['symbolic', 'semantic'] as Array<'symbolic' | 'semantic'>) { + describe(`Cat with ${format} zeta`, () => { + let cat1: Cat, cat2: Cat, cat3: Cat, cat4: Cat, cat5: Cat, cat6: Cat, cat7: Cat, cat8: Cat; + let rng = seedrandom(); + + beforeEach(() => { + cat1 = new Cat(); + cat1.updateAbilityEstimate( + [ + convertZeta({ a: 2.225, b: -1.885, c: 0.21, d: 1 }, format), + convertZeta({ a: 1.174, b: -2.411, c: 0.212, d: 1 }, format), + convertZeta({ a: 2.104, b: -2.439, c: 0.192, d: 1 }, format), + ], + [1, 0, 1], + ); + + cat2 = new Cat(); + cat2.updateAbilityEstimate( + [ + convertZeta({ a: 1, b: -0.447, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 2.869, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.469, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.576, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.43, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.607, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 0.529, c: 0.5, d: 1 }, format), + ], + [0, 1, 0, 1, 1, 1, 1], + ); + cat3 = new Cat({ nStartItems: 0 }); + const randomSeed = 
'test'; + rng = seedrandom(randomSeed); + cat4 = new Cat({ nStartItems: 0, itemSelect: 'RANDOM', randomSeed }); + cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); + + cat6 = new Cat(); + cat6.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0], + ); + + cat7 = new Cat({ method: 'eap' }); + cat7.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0], + ); + + cat8 = new Cat({ nStartItems: 0, itemSelect: 'FIXED' }); + }); + + const s1: Stimulus = { difficulty: 0.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'looking' }; + const s2: Stimulus = { difficulty: 3.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'opaque' }; + const s3: Stimulus = { difficulty: 2, guessing: 0.5, discrimination: 1, slipping: 1, word: 'right' }; + const s4: Stimulus = { difficulty: -2.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'yes' }; + const s5: Stimulus = { difficulty: -1.8, guessing: 0.5, discrimination: 1, slipping: 1, word: 'mom' }; + const stimuli = [s1, s2, s3, s4, s5]; + + it('can update an ability estimate using only a single item and answer', () => { + const cat = new Cat(); + cat.updateAbilityEstimate(s1, 1); + expect(cat.nItems).toEqual(1); + expect(cat.theta).toBeCloseTo(4.572, 1); + }); + + it('constructs an adaptive test', () => { + expect(cat1.method).toBe('mle'); + expect(cat1.itemSelect).toBe('mfi'); + }); + + it('correctly updates ability estimate', () => { + expect(cat1.theta).toBeCloseTo(-1.642307, 1); + }); + + it('correctly updates ability estimate', () => { + expect(cat2.theta).toBeCloseTo(-1.272, 1); + }); + + it('correctly updates standard error of mean of ability estimate', () => { + expect(cat2.seMeasurement).toBeCloseTo(1.71, 1); + }); + + it('correctly counts number of items', () => { + expect(cat2.nItems).toEqual(7); + }); + + it('correctly 
updates answers', () => { + expect(cat2.resps).toEqual([0, 1, 0, 1, 1, 1, 1]); + }); + + it('correctly updates zetas', () => { + expect(cat2.zetas).toEqual([ + convertZeta({ a: 1, b: -0.447, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 2.869, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.469, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.576, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.43, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.607, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 0.529, c: 0.5, d: 1 }, format), + ]); + }); + + it.each` + deepCopy + ${true} + ${false} + `("correctly suggests the next item (closest method) with deepCopy='$deepCopy'", ({ deepCopy }) => { + const expected = { nextStimulus: s5, remainingStimuli: [s4, s1, s3, s2] }; + const received = cat1.findNextItem(stimuli, 'closest', deepCopy); + expect(received).toEqual(expected); + }); + + it.each` + deepCopy + ${true} + ${false} + `("correctly suggests the next item (mfi method) with deepCopy='$deepCopy'", ({ deepCopy }) => { + const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; + const received = cat3.findNextItem(stimuli, 'MFI', deepCopy); + expect(received).toEqual(expected); + }); + + it.each` + deepCopy + ${true} + ${false} + `("correctly suggests the next item (middle method) with deepCopy='$deepCopy'", ({ deepCopy }) => { + const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; + const received = cat5.findNextItem(stimuli, undefined, deepCopy); + expect(received).toEqual(expected); + }); + + it.each` + deepCopy + ${true} + ${false} + `("correctly suggests the next item (fixed method) with deepCopy='$deepCopy'", ({ deepCopy }) => { + expect(cat8.itemSelect).toBe('fixed'); + const expected = { nextStimulus: s1, remainingStimuli: [s2, s3, s4, s5] }; + const received = cat8.findNextItem(stimuli, undefined, deepCopy); + expect(received).toEqual(expected); + }); + + it.each` + deepCopy + ${true} + ${false} + 
`("correctly suggests the next item (random method) with deepCopy='$deepCopy'", ({ deepCopy }) => { + let received; + const stimuliSorted = stimuli.sort((a: Stimulus, b: Stimulus) => a.difficulty! - b.difficulty!); // ask + let index = Math.floor(rng() * stimuliSorted.length); + received = cat4.findNextItem(stimuliSorted, undefined, deepCopy); + expect(received.nextStimulus).toEqual(stimuliSorted[index]); + + for (let i = 0; i < 3; i++) { + const remainingStimuli = received.remainingStimuli; + index = Math.floor(rng() * remainingStimuli.length); + received = cat4.findNextItem(remainingStimuli, undefined, deepCopy); + expect(received.nextStimulus).toEqual(remainingStimuli[index]); + } + }); + + it('correctly updates ability estimate through MLE', () => { + expect(cat6.theta).toBeCloseTo(-6.0, 1); + }); + + it('correctly updates ability estimate through EAP', () => { + expect(cat7.theta).toBeCloseTo(0.25, 1); + }); + + it('should throw an error if zeta and answers do not have matching length', () => { + try { + cat7.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0, 0], + ); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should throw an error if method is invalid', () => { + try { + new Cat({ method: 'coolMethod' }); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + + try { + cat7.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0], + 'coolMethod', + ); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should throw an error if itemSelect is invalid', () => { + try { + new Cat({ itemSelect: 'coolMethod' }); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + + try { + cat7.findNextItem(stimuli, 'coolMethod'); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should throw an 
error if startSelect is invalid', () => { + try { + new Cat({ startSelect: 'coolMethod' }); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should return undefined if there are no input items', () => { + const cat10 = new Cat(); + const { nextStimulus } = cat10.findNextItem([]); + expect(nextStimulus).toBeUndefined(); + }); + }); +} diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts new file mode 100644 index 0000000..998499d --- /dev/null +++ b/src/__tests__/clowder.test.ts @@ -0,0 +1,588 @@ +import { Cat, Clowder, ClowderInput } from '..'; +import { MultiZetaStimulus, Zeta, ZetaCatMap } from '../type'; +import { defaultZeta } from '../corpus'; +import _uniq from 'lodash/uniq'; +import { StopAfterNItems, StopIfSEMeasurementBelowThreshold, StopOnSEMeasurementPlateau } from '../stopping'; + +const createStimulus = (id: string) => ({ + ...defaultZeta(), + id, + content: `Stimulus content ${id}`, +}); + +const createMultiZetaStimulus = (id: string, zetas: ZetaCatMap[]): MultiZetaStimulus => ({ + id, + content: `Multi-Zeta Stimulus content ${id}`, + zetas, +}); + +const createZetaCatMap = (catNames: string[], zeta: Zeta = defaultZeta()): ZetaCatMap => ({ + cats: catNames, + zeta, +}); + +describe('Clowder Class', () => { + let clowder: Clowder; + + beforeEach(() => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + cat2: { method: 'EAP', theta: -1.0 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1']), createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1']), createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('4', []), // Unvalidated item + ], + }; + clowder = new Clowder(clowderInput); + }); + + it('initializes with provided cats and corpora', () => { + 
expect(Object.keys(clowder.cats)).toContain('cat1'); + expect(clowder.remainingItems).toHaveLength(5); + expect(clowder.corpus).toHaveLength(5); + expect(clowder.seenItems).toHaveLength(0); + }); + + it('throws an error when given an invalid corpus', () => { + expect(() => { + const corpus: MultiZetaStimulus[] = [ + { + id: 'item1', + content: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + { cats: ['Model C'], zeta: { a: 1, b: 2, c: 0.3, d: 0.9 } }, + ], + }, + { + id: 'item2', + content: 'Item 2', + zetas: [{ cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + + new Clowder({ cats: { cat1: {} }, corpus }); + }).toThrowError('The cat names Model C are present in multiple corpora.'); + }); + + it('validates cat names', () => { + expect(() => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'invalidCat', + }); + }).toThrowError('Invalid Cat name'); + }); + + it('updates ability estimates only for the named cats', () => { + const origTheta1 = clowder.cats.cat1.theta; + const origTheta2 = clowder.cats.cat2.theta; + + clowder.updateAbilityEstimates(['cat1'], createStimulus('1'), [0]); + + expect(clowder.cats.cat1.theta).not.toBe(origTheta1); + expect(clowder.cats.cat2.theta).toBe(origTheta2); + }); + + it('throws an error when updating ability estimates for an invalid cat', () => { + expect(() => clowder.updateAbilityEstimates(['invalidCatName'], createStimulus('1'), [0])).toThrowError( + 'Invalid Cat name. Expected one of cat1, cat2. 
Received invalidCatName.', + ); + }); + + it('should return undefined if no validated items remain and returnUndefinedOnExhaustion is true', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + // Use all the validated items for cat1 + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[0], clowder.corpus[1]], + answers: [1, 1], + }); + + // Try to get another validated item for cat1 with returnUndefinedOnExhaustion set to true + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + returnUndefinedOnExhaustion: true, + }); + expect(clowder.stoppingReason).toBe('No validated items remaining for specified catToSelect'); + expect(nextItem).toBeUndefined(); + }); + + it('should return an item from missing if catToSelect is "unvalidated", no unvalidated items remain, and returnUndefinedOnExhaustion is false', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), // Validated item + createMultiZetaStimulus('1', [createZetaCatMap([])]), // Unvalidated item + createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), // Validated item + ], + }; + + const clowder = new Clowder(clowderInput); + + // Exhaust the unvalidated items + clowder.updateCatAndGetNextItem({ + catToSelect: 'unvalidated', + items: [clowder.corpus[1]], + answers: [1], + }); + + const nDraws = 50; + // Simulate nDraws item selections + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for (const _ of Array(nDraws).fill(0)) { + // Attempt to get another unvalidated item with returnUndefinedOnExhaustion set to false + const nextItem =
clowder.updateCatAndGetNextItem({ + catToSelect: 'unvalidated', + returnUndefinedOnExhaustion: false, + }); + + // Should return a validated item since no unvalidated items remain + expect(['0', '2']).toContain(nextItem?.id); // Item ID should match any of the items validated for cat1 + } + }); + + it.each` + property + ${'theta'} + ${'seMeasurement'} + ${'nItems'} + ${'resps'} + ${'zetas'} + `("accesses the '$property' property of each cat", ({ property }) => { + clowder.updateAbilityEstimates(['cat1'], createStimulus('1'), [0]); + clowder.updateAbilityEstimates(['cat2'], createStimulus('1'), [1]); + const expected = { + cat1: clowder.cats['cat1'][property as keyof Cat], + cat2: clowder.cats['cat2'][property as keyof Cat], + }; + expect(clowder[property as keyof Clowder]).toEqual(expected); + }); + + it('throws an error if items and answers have mismatched lengths', () => { + expect(() => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + items: createMultiZetaStimulus('1', [createZetaCatMap(['cat1']), createZetaCatMap(['cat2'])]), + answers: [1, 0], // Mismatched length + }); + }).toThrow('Previous items and answers must have the same length.'); + }); + + it('throws an error if catToSelect is invalid', () => { + expect(() => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'invalidCatName', + }); + }).toThrow('Invalid Cat name. Expected one of cat1, cat2, unvalidated. Received invalidCatName.'); + }); + + it('throws an error if any of catsToUpdate is invalid', () => { + expect(() => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['invalidCatName', 'cat2'], + }); + }).toThrow('Invalid Cat name. Expected one of cat1, cat2.
Received invalidCatName.'); + }); + + it('updates seen and remaining items', () => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat2', + catsToUpdate: ['cat1', 'cat2'], + items: [clowder.corpus[0], clowder.corpus[1], clowder.corpus[2]], + answers: [1, 1, 1], + }); + + expect(clowder.seenItems).toHaveLength(3); + expect(clowder.remainingItems).toHaveLength(2); + }); + + it('should select an item that has not yet been seen', () => { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat2', + catsToUpdate: ['cat1', 'cat2'], + items: [clowder.corpus[0], clowder.corpus[1], clowder.corpus[2]], + answers: [1, 1, 1], + }); + + expect([clowder.corpus[3], clowder.corpus[4]]).toContainEqual(nextItem); // Must be one of the two items not yet seen + }); + + it('should select a validated item if validated items are present and randomlySelectUnvalidated is false', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + cat2: { method: 'EAP', theta: -1.0 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat2'])]), + ], + }; + const clowder = new Clowder(clowderInput); + + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + randomlySelectUnvalidated: false, + }); + expect(nextItem?.id).toMatch(/^(0|1)$/); + }); + + it('should return an item from missing if no validated items remain and returnUndefinedOnExhaustion is false', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + cat2: { method: 'EAP', theta: -1.0 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat2'])]), // Validated for cat2 + createMultiZetaStimulus('1', [createZetaCatMap(['cat2'])]), // Validated for cat2 + createMultiZetaStimulus('2', [createZetaCatMap([])]), // Unvalidated + ], + }; + + const clowder = new Clowder(clowderInput); + + // Should return an item from `missing`, which are items
validated for cat2 or unvalidated + const nDraws = 50; + // Simulate nDraws item selections for cat1 + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for (const _ of Array(nDraws).fill(0)) { + // Attempt to select an item for cat1, which has no validated items in the corpus + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + returnUndefinedOnExhaustion: false, // Ensure fallback is enabled + }); + expect(['0', '1', '2']).toContain(nextItem?.id); // Item ID should match a cat2 item or the unvalidated item + } + }); + + it('should select an unvalidated item if catToSelect is "unvalidated"', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap([])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap([])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat1'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + const nDraws = 50; + // Simulate nDraws unvalidated items being selected + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for (const _ of Array(nDraws).fill(0)) { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'unvalidated', + }); + + expect(['0', '2']).toContain(nextItem?.id); + } + }); + + it('should not update cats with items that do not have parameters for that cat', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + cat2: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat2'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + clowder.updateCatAndGetNextItem({ + catsToUpdate: ['cat1', 'cat2'], + items:
clowder.corpus, + answers: [1, 1, 1, 1], + catToSelect: 'unvalidated', + }); + + expect(clowder.nItems.cat1).toBe(2); + expect(clowder.nItems.cat2).toBe(2); + }); + + it('should not update any cats if only unvalidated items have been seen', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap([])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap([])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat1'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + clowder.updateCatAndGetNextItem({ + catsToUpdate: ['cat1'], + items: [clowder.corpus[0], clowder.corpus[2]], + answers: [1, 1], + catToSelect: 'unvalidated', + }); + + expect(clowder.nItems.cat1).toBe(0); + }); + + it('should return undefined for next item if catToSelect = "unvalidated" and no unvalidated items remain', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap([])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap([])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat1'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + const nextItem = clowder.updateCatAndGetNextItem({ + catsToUpdate: ['cat1'], + items: [clowder.corpus[0], clowder.corpus[2]], + answers: [1, 1], + catToSelect: 'unvalidated', + }); + expect(clowder.stoppingReason).toBe('No unvalidated items remaining'); + expect(nextItem).toBeUndefined(); + }); + + it('should correctly update ability estimates during the updateCatAndGetNextItem method', () => { + const originalTheta = clowder.cats.cat1.theta; + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[0]], + answers: [1], + }); + 
expect(clowder.cats.cat1.theta).not.toBe(originalTheta); + }); + + it('should randomly choose between validated and unvalidated items if randomlySelectUnvalidated is true', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), // Validated item + createMultiZetaStimulus('1', [createZetaCatMap([])]), // Unvalidated item + createMultiZetaStimulus('2', [createZetaCatMap([])]), // Unvalidated item + createMultiZetaStimulus('3', [createZetaCatMap([])]), // Unvalidated item + ], + randomSeed: 'randomSeed', + }; + const clowder = new Clowder(clowderInput); + + const nextItems = Array(20) + .fill('-1') + .map(() => { + return clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + randomlySelectUnvalidated: true, + }); + }); + + const itemsId = nextItems.map((item) => item?.id); + + expect(nextItems).toBeDefined(); + expect(_uniq(itemsId)).toEqual(expect.arrayContaining(['0', '1', '2', '3'])); // Could be validated or unvalidated + }); + + it('should return undefined if no more items remain', () => { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + items: clowder.remainingItems, + answers: [1, 0, 1, 1, 0], // Exhaust all items + }); + + expect(nextItem).toBeUndefined(); + }); + + it('can receive one item and answer as an input', () => { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + items: clowder.corpus[0], + answers: 1, + }); + expect(nextItem).toBeDefined(); + }); + + it('can receive only one catToUpdate', () => { + const originalTheta = clowder.cats.cat1.theta; + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: 'cat1', + items: clowder.corpus[0], + answers: 1, + }); + expect(nextItem).toBeDefined(); + expect(clowder.cats.cat1.theta).not.toBe(originalTheta); + }); + + it('should update early stopping conditions based on number of items presented', ()
=> { + const earlyStopping = new StopOnSEMeasurementPlateau({ + patience: { cat1: 2 }, // Requires 2 items to check for plateau + tolerance: { cat1: 0.05 }, // SE change tolerance + }); + + const clowder = new Clowder({ + cats: { cat1: { method: 'MLE', theta: 0.5 } }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + ], + earlyStopping, + }); + + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[0]], + answers: [1], + }); + + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[1]], + answers: [1], + }); + + expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after 2 items + expect(clowder.stoppingReason).toBe('Early stopping'); + expect(nextItem).toBe(undefined); // Expect undefined after early stopping + }); +}); + +describe('Clowder Early Stopping', () => { + let clowder: Clowder; + + beforeEach(() => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + ], + }; + clowder = new Clowder(clowderInput); + }); + + it('should trigger early stopping after required number of items', () => { + const earlyStopping = new StopAfterNItems({ + requiredItems: { cat1: 2 }, // Stop after 2 items + }); + + clowder = new Clowder({ + cats: { cat1: { method: 'MLE', theta: 0.5 } }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), // This item should trigger early stopping + createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), + ], + earlyStopping, + }); + + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[0]], + answers: [1], + 
}); + + expect(clowder.earlyStopping?.earlyStop).toBe(false); + + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[1]], + answers: [1], + }); + + expect(clowder.earlyStopping?.earlyStop).toBe(true); // Early stop should be triggered after 2 items + expect(nextItem).toBe(undefined); // No further items should be selected + expect(clowder.stoppingReason).toBe('Early stopping'); + }); + + it('should handle StopIfSEMeasurementBelowThreshold condition', () => { + const earlyStopping = new StopIfSEMeasurementBelowThreshold({ + seMeasurementThreshold: { cat1: 0.2 }, // Stop if SE drops below 0.2 + patience: { cat1: 2 }, + tolerance: { cat1: 0.01 }, + }); + + const zetaMap = createZetaCatMap(['cat1'], { + a: 6, + b: 6, + c: 0, + d: 1, + }); + + const corpus = [ + createMultiZetaStimulus('0', [zetaMap]), + createMultiZetaStimulus('1', [zetaMap]), + createMultiZetaStimulus('2', [zetaMap]), // Here the SE measurement drops below threshold + createMultiZetaStimulus('3', [zetaMap]), // And here, early stopping should be triggered because it has been below threshold for 2 items + ]; + + clowder = new Clowder({ + cats: { cat1: { method: 'MLE', theta: 0.5 } }, + corpus, + earlyStopping, + }); + + for (const item of corpus) { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [item], + answers: [1], + }); + + if (item.id === '3') { + expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after SE drops below threshold + expect(clowder.stoppingReason).toBe('Early stopping'); + expect(nextItem).toBe(undefined); // No further items should be selected + } else { + expect(clowder.earlyStopping?.earlyStop).toBe(false); + expect(nextItem).toBeDefined(); + } + } + }); +}); diff --git a/src/__tests__/corpus.test.ts b/src/__tests__/corpus.test.ts new file mode 100644 index 0000000..d9d79d8 --- /dev/null +++ b/src/__tests__/corpus.test.ts @@ -0,0 
+1,431 @@ +import { MultiZetaStimulus, Stimulus, Zeta } from '../type'; +import { + validateZetaParams, + ZETA_KEY_MAP, + defaultZeta, + fillZetaDefaults, + convertZeta, + checkNoDuplicateCatNames, + filterItemsByCatParameterAvailability, +} from '../corpus'; +import { prepareClowderCorpus } from '..'; +import _omit from 'lodash/omit'; + +describe('validateZetaParams', () => { + it('throws an error when providing both a and discrimination', () => { + expect(() => validateZetaParams({ a: 1, discrimination: 1 })).toThrow( + 'This item has both an `a` key and `discrimination` key. Please provide only one.', + ); + }); + + it('throws an error when providing both b and difficulty', () => { + expect(() => validateZetaParams({ b: 1, difficulty: 1 })).toThrow( + 'This item has both a `b` key and `difficulty` key. Please provide only one.', + ); + }); + + it('throws an error when providing both c and guessing', () => { + expect(() => validateZetaParams({ c: 1, guessing: 1 })).toThrow( + 'This item has both a `c` key and `guessing` key. Please provide only one.', + ); + }); + + it('throws an error when providing both d and slipping', () => { + expect(() => validateZetaParams({ d: 1, slipping: 1 })).toThrow( + 'This item has both a `d` key and `slipping` key. 
Please provide only one.', + ); + }); + + it('throws an error when requiring all keys and missing one', () => { + for (const key of ['a', 'b', 'c', 'd'] as (keyof typeof ZETA_KEY_MAP)[]) { + const semanticKey = ZETA_KEY_MAP[key]; + const zeta = _omit(defaultZeta('symbolic'), [key]); + + expect(() => validateZetaParams(zeta, true)).toThrow( + `This item is missing the key \`${String(key)}\` or \`${semanticKey}\`.`, + ); + } + }); +}); + +describe('fillZetaDefaults', () => { + it('fills in default values for missing keys', () => { + const zeta: Zeta = { + difficulty: 1, + guessing: 0.5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'semantic'); + + expect(filledZeta).toEqual({ + discrimination: 1, + difficulty: 1, + guessing: 0.5, + slipping: 1, + }); + }); + + it('does not modify the input object when no missing keys', () => { + const zeta: Zeta = { + a: 5, + b: 5, + c: 5, + d: 5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'symbolic'); + + expect(filledZeta).toEqual(zeta); + }); + + it('converts to semantic format when desired', () => { + const zeta: Zeta = { + a: 5, + b: 5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'semantic'); + + expect(filledZeta).toEqual({ + difficulty: 5, + discrimination: 5, + guessing: 0, + slipping: 1, + }); + }); + + it('converts to symbolic format when desired', () => { + const zeta: Zeta = { + difficulty: 5, + discrimination: 5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'symbolic'); + + expect(filledZeta).toEqual({ + a: 5, + b: 5, + c: 0, + d: 1, + }); + }); +}); + +describe('convertZeta', () => { + it('converts from symbolic format to semantic format', () => { + const zeta: Zeta = { + a: 1, + b: 2, + c: 3, + d: 4, + }; + + const convertedZeta = convertZeta(zeta, 'semantic'); + + expect(convertedZeta).toEqual({ + discrimination: 1, + difficulty: 2, + guessing: 3, + slipping: 4, + }); + }); + + it('converts from semantic format to symbolic format', () => { + const zeta: Zeta = { + discrimination: 1, + 
difficulty: 2, + guessing: 3, + slipping: 4, + }; + + const convertedZeta = convertZeta(zeta, 'symbolic'); + + expect(convertedZeta).toEqual({ + a: 1, + b: 2, + c: 3, + d: 4, + }); + }); + + it('throws an error when converting from an unsupported format', () => { + const zeta: Zeta = { + a: 1, + b: 2, + c: 3, + d: 4, + }; + + expect(() => convertZeta(zeta, 'unsupported' as 'symbolic')).toThrow( + "Invalid desired format. Expected 'symbolic' or'semantic'. Received unsupported instead.", + ); + }); + + it('does not modify other keys when converting', () => { + const zeta: Stimulus = { + a: 1, + b: 2, + c: 3, + d: 4, + key1: 5, + key2: 6, + key3: 7, + key4: 8, + }; + + const convertedZeta = convertZeta(zeta, 'semantic'); + + expect(convertedZeta).toEqual({ + discrimination: 1, + difficulty: 2, + guessing: 3, + slipping: 4, + key1: 5, + key2: 6, + key3: 7, + key4: 8, + }); + }); + + it('converts only existing keys', () => { + const zeta: Zeta = { + a: 1, + b: 2, + }; + + const convertedZeta = convertZeta(zeta, 'semantic'); + + expect(convertedZeta).toEqual({ + discrimination: 1, + difficulty: 2, + }); + }); +}); + +describe('checkNoDuplicateCatNames', () => { + it('should throw an error when a cat name is present in multiple zetas', () => { + const corpus: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + { cats: ['Model C'], zeta: { a: 1, b: 2, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + expect(() => checkNoDuplicateCatNames(corpus)).toThrowError( + 'The cat names Model C are present in multiple corpora.', + ); + }); + + it('should not throw an error when a cat name is not present in multiple corpora', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 
'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + + expect(() => checkNoDuplicateCatNames(items)).not.toThrowError(); + }); + + it('should handle an empty corpus without throwing an error', () => { + const emptyCorpus: MultiZetaStimulus[] = []; + + expect(() => checkNoDuplicateCatNames(emptyCorpus)).not.toThrowError(); + }); +}); + +describe('filterItemsByCatParameterAvailability', () => { + it('returns an empty "available" array when no items match the catname', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + + const result = filterItemsByCatParameterAvailability(items, 'Model D'); + + expect(result.available).toEqual([]); + expect(result.missing).toEqual(items); + }); + + it('returns empty missing array when all items match the catname', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model A'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [ + { cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, + { cats: ['Model A'], zeta: { a: 3, b: 0.9, c: 0.4, d: 0.99 } }, + ], + }, + ]; + + const result = filterItemsByCatParameterAvailability(items, 'Model A'); + + expect(result.missing).toEqual([]); + expect(result.available).toEqual(items); + }); + + it('separates items based on matching catnames', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + 
zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + { + stimulus: 'Item 3', + zetas: [{ cats: ['Model A'], zeta: { a: 3, b: 0.9, c: 0.4, d: 0.99 } }], + }, + ]; + + const result = filterItemsByCatParameterAvailability(items, 'Model A'); + + // Assert + expect(result.available.length).toBe(2); + expect(result.available[0].stimulus).toBe('Item 1'); + expect(result.available[1].stimulus).toBe('Item 3'); + expect(result.missing.length).toBe(1); + expect(result.missing[0].stimulus).toBe('Item 2'); + }); +}); + +describe('prepareClowderCorpus', () => { + it('converts a Stimulus array to a MultiZetaStimulus array with symbolic format', () => { + const items: Stimulus[] = [ + { + 'cat1.a': 1, + 'cat1.b': 2, + 'cat1.c': 3, + 'cat1.d': 4, + 'foo.a': 5, + 'foo.b': 6, + 'foo.c': 7, + 'foo.d': 8, + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + }, + ]; + + const result = prepareClowderCorpus(items, ['cat1', 'foo'], '.'); + + expect(result).toEqual([ + { + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + zetas: [ + { + cats: ['cat1'], + zeta: { a: 1, b: 2, c: 3, d: 4 }, + }, + { + cats: ['foo'], + zeta: { a: 5, b: 6, c: 7, d: 8 }, + }, + ], + }, + ]); + }); + + it('converts a Stimulus array to a MultiZetaStimulus array with semantic format', () => { + const items: Stimulus[] = [ + { + 'cat1.a': 1, + 'cat1.b': 2, + 'cat1.c': 3, + 'cat1.d': 4, + 'foo.a': 5, + 'foo.b': 6, + 'foo.c': 7, + 'foo.d': 8, + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + }, + ]; + + const result = prepareClowderCorpus(items, ['cat1', 'foo'], '.', 'semantic'); + + expect(result).toEqual([ + { + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + zetas: [ + { + cats: ['cat1'], + zeta: { discrimination: 1, difficulty: 2, guessing: 3, slipping: 4 }, + 
}, + { + cats: ['foo'], + zeta: { discrimination: 5, difficulty: 6, guessing: 7, slipping: 8 }, + }, + ], + }, + ]); + }); + + it('handles cases with different delimiters', () => { + const items: Stimulus[] = [ + { + cat1_a: 1, + cat1_b: 2, + foo_a: 5, + foo_b: 6, + stimulus: 'stim1', + type: 'jspsychHtmlMultiResponse', + }, + ]; + + const result = prepareClowderCorpus(items, ['cat1', 'foo'], '_', 'symbolic'); + + expect(result).toEqual([ + { + stimulus: 'stim1', + type: 'jspsychHtmlMultiResponse', + zetas: [ + { + cats: ['cat1'], + zeta: { a: 1, b: 2 }, + }, + { + cats: ['foo'], + zeta: { a: 5, b: 6 }, + }, + ], + }, + ]); + }); +}); diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts deleted file mode 100644 index 1194090..0000000 --- a/src/__tests__/index.test.ts +++ /dev/null @@ -1,207 +0,0 @@ -import { Cat } from '../index'; -import { Stimulus } from '../type'; -import seedrandom from 'seedrandom'; - -describe('Cat', () => { - let cat1: Cat, cat2: Cat, cat3: Cat, cat4: Cat, cat5: Cat, cat6: Cat, cat7: Cat, cat8: Cat; - let rng = seedrandom(); - beforeEach(() => { - cat1 = new Cat(); - cat1.updateAbilityEstimate( - [ - { a: 2.225, b: -1.885, c: 0.21, d: 1 }, - { a: 1.174, b: -2.411, c: 0.212, d: 1 }, - { a: 2.104, b: -2.439, c: 0.192, d: 1 }, - ], - [1, 0, 1], - ); - - cat2 = new Cat(); - cat2.updateAbilityEstimate( - [ - { a: 1, b: -0.447, c: 0.5, d: 1 }, - { a: 1, b: 2.869, c: 0.5, d: 1 }, - { a: 1, b: -0.469, c: 0.5, d: 1 }, - { a: 1, b: -0.576, c: 0.5, d: 1 }, - { a: 1, b: -1.43, c: 0.5, d: 1 }, - { a: 1, b: -1.607, c: 0.5, d: 1 }, - { a: 1, b: 0.529, c: 0.5, d: 1 }, - ], - [0, 1, 0, 1, 1, 1, 1], - ); - cat3 = new Cat({ nStartItems: 0 }); - const randomSeed = 'test'; - rng = seedrandom(randomSeed); - cat4 = new Cat({ nStartItems: 0, itemSelect: 'RANDOM', randomSeed }); - cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); - - cat6 = new Cat(); - cat6.updateAbilityEstimate( - [ - { a: 1, b: -4.0, c: 0.5, d: 1 }, - { a: 1, b: -3.0, 
c: 0.5, d: 1 }, - ], - [0, 0], - ); - - cat7 = new Cat({ method: 'eap' }); - cat7.updateAbilityEstimate( - [ - { a: 1, b: -4.0, c: 0.5, d: 1 }, - { a: 1, b: -3.0, c: 0.5, d: 1 }, - ], - [0, 0], - ); - - cat8 = new Cat({ nStartItems: 0, itemSelect: 'FIXED' }); - }); - - const s1: Stimulus = { difficulty: 0.5, c: 0.5, word: 'looking' }; - const s2: Stimulus = { difficulty: 3.5, c: 0.5, word: 'opaque' }; - const s3: Stimulus = { difficulty: 2, c: 0.5, word: 'right' }; - const s4: Stimulus = { difficulty: -2.5, c: 0.5, word: 'yes' }; - const s5: Stimulus = { difficulty: -1.8, c: 0.5, word: 'mom' }; - const stimuli = [s1, s2, s3, s4, s5]; - - it('constructs an adaptive test', () => { - expect(cat1.method).toBe('mle'); - expect(cat1.itemSelect).toBe('mfi'); - }); - - it('correctly updates ability estimate', () => { - expect(cat1.theta).toBeCloseTo(-1.642307, 1); - }); - - it('correctly updates ability estimate', () => { - expect(cat2.theta).toBeCloseTo(-1.272, 1); - }); - - it('correctly updates standard error of mean of ability estimate', () => { - expect(cat2.seMeasurement).toBeCloseTo(1.71, 1); - }); - - it('correctly counts number of items', () => { - expect(cat2.nItems).toEqual(7); - }); - - it('correctly updates answers', () => { - expect(cat2.resps).toEqual([0, 1, 0, 1, 1, 1, 1]); - }); - - it('correctly updates zatas', () => { - expect(cat2.zetas).toEqual([ - { a: 1, b: -0.447, c: 0.5, d: 1 }, - { a: 1, b: 2.869, c: 0.5, d: 1 }, - { a: 1, b: -0.469, c: 0.5, d: 1 }, - { a: 1, b: -0.576, c: 0.5, d: 1 }, - { a: 1, b: -1.43, c: 0.5, d: 1 }, - { a: 1, b: -1.607, c: 0.5, d: 1 }, - { a: 1, b: 0.529, c: 0.5, d: 1 }, - ]); - }); - - it('correctly suggests the next item (closest method)', () => { - const expected = { nextStimulus: s5, remainingStimuli: [s4, s1, s3, s2] }; - const received = cat1.findNextItem(stimuli, 'closest'); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (mfi method)', () => { - const expected = { nextStimulus: s1, 
remainingStimuli: [s4, s5, s3, s2] }; - const received = cat3.findNextItem(stimuli, 'MFI'); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (middle method)', () => { - const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; - const received = cat5.findNextItem(stimuli); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (fixed method)', () => { - expect(cat8.itemSelect).toBe('fixed'); - const expected = { nextStimulus: s1, remainingStimuli: [s2, s3, s4, s5] }; - const received = cat8.findNextItem(stimuli); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (random method)', () => { - let received; - const stimuliSorted = stimuli.sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty); - let index = Math.floor(rng() * stimuliSorted.length); - received = cat4.findNextItem(stimuliSorted); - expect(received.nextStimulus).toEqual(stimuliSorted[index]); - - for (let i = 0; i < 3; i++) { - const remainingStimuli = received.remainingStimuli; - index = Math.floor(rng() * remainingStimuli.length); - received = cat4.findNextItem(remainingStimuli); - expect(received.nextStimulus).toEqual(remainingStimuli[index]); - } - }); - - it('correctly updates ability estimate through MLE', () => { - expect(cat6.theta).toBeCloseTo(-6.0, 1); - }); - - it('correctly updates ability estimate through EAP', () => { - expect(cat7.theta).toBeCloseTo(0.25, 1); - }); - - it('should throw a error if zeta and answers do not have matching length', () => { - try { - cat7.updateAbilityEstimate( - [ - { a: 1, b: -4.0, c: 0.5, d: 1 }, - { a: 1, b: -3.0, c: 0.5, d: 1 }, - ], - [0, 0, 0], - ); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); - - it('should throw a error if method is invalid', () => { - try { - new Cat({ method: 'coolMethod' }); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - - try { - cat7.updateAbilityEstimate( - [ - { a: 
1, b: -4.0, c: 0.5, d: 1 }, - { a: 1, b: -3.0, c: 0.5, d: 1 }, - ], - [0, 0], - 'coolMethod', - ); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); - - it('should throw a error if itemSelect is invalid', () => { - try { - new Cat({ itemSelect: 'coolMethod' }); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - - try { - cat7.findNextItem(stimuli, 'coolMethod'); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); - - it('should throw a error if startSelect is invalid', () => { - try { - new Cat({ startSelect: 'coolMethod' }); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); -}); diff --git a/src/__tests__/stopping.test.ts b/src/__tests__/stopping.test.ts new file mode 100644 index 0000000..0676d33 --- /dev/null +++ b/src/__tests__/stopping.test.ts @@ -0,0 +1,848 @@ +import { Cat } from '..'; +import { CatMap } from '../type'; +import { EarlyStopping, StopAfterNItems, StopIfSEMeasurementBelowThreshold, StopOnSEMeasurementPlateau } from '../'; +import { + StopAfterNItemsInput, + StopIfSEMeasurementBelowThresholdInput, + StopOnSEMeasurementPlateauInput, +} from '../stopping'; +import { toBeBoolean } from 'jest-extended'; +expect.extend({ toBeBoolean }); + +/* eslint-disable @typescript-eslint/no-explicit-any */ +type Class = new (...args: any[]) => T; + +const testLogicalOperationValidation = ( + stoppingClass: Class, + input: StopAfterNItemsInput | StopIfSEMeasurementBelowThresholdInput | StopOnSEMeasurementPlateauInput, +) => { + expect(() => new stoppingClass(input)).toThrowError( + `Invalid logical operation. Expected "and", "or", or "only". 
Received "${input.logicalOperation}"`, + ); +}; + +const testInstantiation = ( + earlyStopping: EarlyStopping, + input: StopAfterNItemsInput | StopIfSEMeasurementBelowThresholdInput | StopOnSEMeasurementPlateauInput, +) => { + if (earlyStopping instanceof StopAfterNItems) { + expect(earlyStopping.requiredItems).toEqual((input as StopAfterNItems).requiredItems ?? {}); + } + + if ( + earlyStopping instanceof StopOnSEMeasurementPlateau || + earlyStopping instanceof StopIfSEMeasurementBelowThreshold + ) { + expect(earlyStopping.patience).toEqual((input as StopOnSEMeasurementPlateauInput).patience ?? {}); + expect(earlyStopping.tolerance).toEqual((input as StopOnSEMeasurementPlateauInput).tolerance ?? {}); + } + + if (earlyStopping instanceof StopIfSEMeasurementBelowThreshold) { + expect(earlyStopping.seMeasurementThreshold).toEqual( + (input as StopIfSEMeasurementBelowThresholdInput).seMeasurementThreshold ?? {}, + ); + } + + expect(earlyStopping.logicalOperation).toBe(input.logicalOperation?.toLowerCase() ?? 
'or'); + expect(earlyStopping.earlyStop).toBeBoolean(); +}; + +const testInternalState = (earlyStopping: EarlyStopping) => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.nItems.cat1).toBe(1); + expect(earlyStopping.seMeasurements.cat1).toEqual([0.5]); + expect(earlyStopping.nItems.cat2).toBe(1); + expect(earlyStopping.seMeasurements.cat2).toEqual([0.3]); + + earlyStopping.update(updates[1]); + expect(earlyStopping.nItems.cat1).toBe(2); + expect(earlyStopping.seMeasurements.cat1).toEqual([0.5, 0.5]); + expect(earlyStopping.nItems.cat2).toBe(2); + expect(earlyStopping.seMeasurements.cat2).toEqual([0.3, 0.3]); +}; + +describe.each` + logicalOperation + ${'and'} + ${'or'} +`("StopOnSEMeasurementPlateau (with logicalOperation='$logicalOperation'", ({ logicalOperation }) => { + let earlyStopping: StopOnSEMeasurementPlateau; + let input: StopOnSEMeasurementPlateauInput; + + beforeEach(() => { + input = { + patience: { cat1: 2, cat2: 3 }, + tolerance: { cat1: 0.01, cat2: 0.02 }, + logicalOperation, + }; + earlyStopping = new StopOnSEMeasurementPlateau(input); + }); + + it('instantiates with input parameters', () => testInstantiation(earlyStopping, input)); + it('validates input', () => + testLogicalOperationValidation(StopOnSEMeasurementPlateau, { ...input, logicalOperation: 'invalid' as 'and' })); + it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); + + it('stops when the seMeasurement has plateaued', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + // cat1 should trigger stopping if logicalOperator === 'or', 
because + // seMeasurement plateaued over the patience period of 2 items + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.5, + } as Cat, + // cat2 should trigger stopping if logicalOperator === 'and', because + // seMeasurement plateaued over the patience period of 3 items, and the + // cat1 criterion passed last update + cat2: { + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(true); + } + }); + + it('does not stop when the seMeasurement has not plateaued', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 100, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 100, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 10, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 1, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 1, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + }); + + it('waits for `patience` items to monitor the seMeasurement plateau', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 100, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, 
+ { + cat1: { + nItems: 3, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(true); + }); + + it('triggers early stopping when within tolerance', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.4, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.395, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.99, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.39, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(true); + }); +}); + +describe.each` + logicalOperation + ${'and'} + ${'or'} +`("StopAfterNItems (with logicalOperation='$logicalOperation'", ({ logicalOperation }) => { + let earlyStopping: StopAfterNItems; + let input: StopAfterNItemsInput; + + beforeEach(() => { + input = { + requiredItems: { cat1: 2, cat2: 3 }, + logicalOperation, + }; + earlyStopping = new StopAfterNItems(input); + }); + + it('instantiates with input parameters', () => testInstantiation(earlyStopping, input)); + it('validates input', () => + testLogicalOperationValidation(StopAfterNItems, { ...input, logicalOperation: 'invalid' as 'and' })); + it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); + + it('does not step when it has not seen required items', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + 
nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Do not increment nItems for cat1 + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Do not increment nItems for cat1 + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + // Do not increment nItems for cat2 + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Increment nItems for cat1, but only use this update if + // logicalOperation is 'and'. Early stopping should still not be + // triggered. + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + // Do not increment nItems for cat2 + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + + if (earlyStopping.logicalOperation === 'and') { + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(false); + } + }); + + it('stops when it has seen required items', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Do not increment nItems for cat1 + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Do not increment nItems for cat1 + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + // Cat2 reaches required items + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Cat1 reaches required items + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + // Cat2 reaches required items + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[1]); 
+ expect(earlyStopping.earlyStop).toBe(false); + + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(true); + } + }); +}); + +describe('EarlyStopping with logicalOperation "only"', () => { + let earlyStopping: StopOnSEMeasurementPlateau; + let input: StopOnSEMeasurementPlateauInput; + + beforeEach(() => { + input = { + patience: { cat1: 2, cat2: 3 }, + tolerance: { cat1: 0.01, cat2: 0.02 }, + logicalOperation: 'only', + }; + earlyStopping = new StopOnSEMeasurementPlateau(input); + }); + + it('throws an error if catToSelect is not provided when logicalOperation is "only"', () => { + expect(() => { + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: 0.5 } as any }, undefined); + }).toThrowError('Must provide a cat to select for "only" stopping condition'); + }); + + it('evaluates the stopping condition when catToSelect is in evaluationCats', () => { + // Add updates to make sure cat1 is included in evaluationCats and has some measurements + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: 0.5 } as any }, 'cat1'); + earlyStopping.update({ cat1: { nItems: 2, seMeasurement: 0.5 } as any }, 'cat1'); + + // Since 'cat1' is in evaluationCats, _earlyStop should be evaluated based on the stopping condition + expect(earlyStopping.earlyStop).toBe(true); // Should be true because seMeasurement has plateaued + }); + it('sets _earlyStop to false when catToSelect is not in evaluationCats', () => { + // Use 'cat3', which is not in the patience or tolerance maps (and thus not in evaluationCats) + earlyStopping.update({ cat3: { nItems: 1, seMeasurement: 0.5 } as any }, 'cat3'); + + // Since 'cat3' is not in evaluationCats, _earlyStop should be false + expect(earlyStopping.earlyStop).toBe(false); + }); +}); + 
+describe('StopIfSEMeasurementBelowThreshold with empty patience and tolerance', () => { + let earlyStopping: StopIfSEMeasurementBelowThreshold; + let input: StopIfSEMeasurementBelowThresholdInput; + + beforeEach(() => { + input = { + seMeasurementThreshold: { cat1: 0.03, cat2: 0.02 }, + logicalOperation: 'only', + }; + earlyStopping = new StopIfSEMeasurementBelowThreshold(input); + }); + + it('should handle updates correctly even with empty patience and tolerance', () => { + // Update the state with some measurements for cat2, where seMeasurement is below the threshold + earlyStopping.update({ cat2: { nItems: 1, seMeasurement: 0.01 } as any }, 'cat2'); + + // Since patience defaults to 1 and tolerance defaults to 0, early stopping should be triggered + expect(earlyStopping.earlyStop).toBe(true); + }); + + it('should not trigger early stopping when seMeasurement does not fall below the threshold', () => { + // Update the state with some measurements for cat1, where seMeasurement is above the threshold + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: 0.05 } as any }, 'cat1'); + + // Early stopping should not be triggered because the seMeasurement is above the threshold + expect(earlyStopping.earlyStop).toBe(false); + }); +}); + +describe('StopIfSEMeasurementBelowThreshold with undefined seMeasurementThreshold for a category', () => { + let earlyStopping: StopIfSEMeasurementBelowThreshold; + let input: StopIfSEMeasurementBelowThresholdInput; + + beforeEach(() => { + input = { + seMeasurementThreshold: {}, // Empty object, meaning no thresholds are defined + patience: { cat1: 2 }, // Setting patience to 2 for cat1 + tolerance: { cat1: 0.01 }, // Small tolerance for cat1 + logicalOperation: 'only', + }; + earlyStopping = new StopIfSEMeasurementBelowThreshold(input); + }); + + it('should use a default seThreshold of 0 when seMeasurementThreshold is not defined for the category', () => { + // Update the state with measurements for cat1, ensuring to meet the 
patience requirement + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: -0.005 } as any }, 'cat1'); + earlyStopping.update({ cat1: { nItems: 2, seMeasurement: -0.01 } as any }, 'cat1'); + + // Early stopping should now be triggered because the seMeasurement has been below the default threshold of 0 for the patience period + expect(earlyStopping.earlyStop).toBe(true); + }); +}); + +describe('StopOnSEMeasurementPlateau without tolerance provided', () => { + let earlyStopping: StopOnSEMeasurementPlateau; + let input: StopOnSEMeasurementPlateauInput; + + beforeEach(() => { + input = { + patience: { cat1: 2 }, + // No tolerance is provided, it should default to an empty object + logicalOperation: 'only', + }; + earlyStopping = new StopOnSEMeasurementPlateau(input); + }); + + it('should handle updates without triggering early stopping when no tolerance is provided', () => { + // Update with measurements for cat1 that are not exactly equal, simulating tolerance as undefined + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: 0.5 } as any }, 'cat1'); + earlyStopping.update({ cat1: { nItems: 2, seMeasurement: 0.55 } as any }, 'cat1'); + + // Since tolerance is undefined, early stopping should not be triggered even if seMeasurements are slightly different + expect(earlyStopping.earlyStop).toBe(false); + }); +}); + +describe.each` + logicalOperation + ${'and'} + ${'or'} +`("StopIfSEMeasurementBelowThreshold (with logicalOperation='$logicalOperation'", ({ logicalOperation }) => { + let earlyStopping: StopIfSEMeasurementBelowThreshold; + let input: StopIfSEMeasurementBelowThresholdInput; + + beforeEach(() => { + input = { + patience: { cat1: 1, cat2: 3 }, + tolerance: { cat1: 0.01, cat2: 0.02 }, + seMeasurementThreshold: { cat1: 0.03, cat2: 0.02 }, + logicalOperation, + }; + earlyStopping = new StopIfSEMeasurementBelowThreshold(input); + }); + + it('instantiates with input parameters', () => testInstantiation(earlyStopping, input)); + it('validates input', () => 
+ testLogicalOperationValidation(StopIfSEMeasurementBelowThreshold, { + ...input, + logicalOperation: 'invalid' as 'and', + })); + it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); + + it('stops when the seMeasurement has fallen below a threshold', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 0.02, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.02, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.02, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.02, + } as Cat, + }, + { + cat1: { + nItems: 4, + seMeasurement: 0.02, + } as Cat, + cat2: { + nItems: 4, + seMeasurement: 0.02, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(true); + } + }); + + it('does not stop when the seMeasurement is above threshold', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 0.1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.1, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 4, + seMeasurement: 0.1, + } as Cat, + cat2: { + nItems: 4, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + 
expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(false); + }); + + it('waits for `patience` items to monitor the seMeasurement plateau', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.01, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.01, + } as Cat, + }, + { + cat1: { + nItems: 4, + seMeasurement: 0.5, + } as Cat, + // Cat2 should trigger when logicalOperation is 'or' + cat2: { + nItems: 4, + seMeasurement: 0.01, + } as Cat, + }, + { + // Cat1 should trigger when logicalOperation is 'and' + // Cat2 criterion was satisfied after last update + cat1: { + nItems: 5, + seMeasurement: 0.01, + } as Cat, + cat2: { + nItems: 5, + seMeasurement: 0.01, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[4]); + expect(earlyStopping.earlyStop).toBe(true); + } + }); + + it('triggers early stopping when within tolerance', () => { + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.4, + } as Cat, + }, + 
{ + cat1: { + nItems: 2, + seMeasurement: 1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.02, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.0001, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.04, + } as Cat, + }, + { + cat1: { + nItems: 4, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 4, + seMeasurement: 0.01, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); + + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(true); + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(false); + } + }); +}); diff --git a/src/__tests__/utils.test.ts b/src/__tests__/utils.test.ts index d8392b5..8895e8e 100644 --- a/src/__tests__/utils.test.ts +++ b/src/__tests__/utils.test.ts @@ -2,40 +2,56 @@ import { itemResponseFunction, fisherInformation, findClosest } from '../utils'; describe('itemResponseFunction', () => { it('correctly calculates the probability', () => { - expect(0.7234).toBeCloseTo(itemResponseFunction(0, { a: 1, b: -0.3, c: 0.35, d: 1 }), 2); - - expect(0.5).toBeCloseTo(itemResponseFunction(0, { a: 1, b: 0, c: 0, d: 1 }), 2); - - expect(0.625).toBeCloseTo(itemResponseFunction(0, { a: 0.5, b: 0, c: 0.25, d: 1 }), 2); + expect(itemResponseFunction(0, { a: 1, b: -0.3, c: 0.35, d: 1 })).toBeCloseTo(0.7234, 2); + expect(itemResponseFunction(0, { a: 1, b: 0, c: 0, d: 1 })).toBeCloseTo(0.5, 2); + expect(itemResponseFunction(0, { a: 0.5, b: 0, c: 0.25, d: 1 })).toBeCloseTo(0.625, 2); }); }); describe('fisherInformation', () => { it('correctly calculates the information', () => { - expect(0.206).toBeCloseTo(fisherInformation(0, { a: 
1.53, b: -0.5, c: 0.5, d: 1 }), 2); - - expect(0.1401).toBeCloseTo(fisherInformation(2.35, { a: 1, b: 2, c: 0.3, d: 1 }), 2); + expect(fisherInformation(0, { a: 1.53, b: -0.5, c: 0.5, d: 1 })).toBeCloseTo(0.206, 2); + expect(fisherInformation(2.35, { a: 1, b: 2, c: 0.3, d: 1 })).toBeCloseTo(0.1401, 2); }); }); describe('findClosest', () => { + const stimuli = [ + { difficulty: 1, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 4, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 10, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 11, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + ]; + it('correctly selects the first item if appropriate', () => { - expect(0).toBe(findClosest([{ difficulty: 1 }, { difficulty: 4 }, { difficulty: 10 }, { difficulty: 11 }], 0)); + expect(findClosest(stimuli, 0)).toBe(0); }); + it('correctly selects the last item if appropriate', () => { - expect(3).toBe(findClosest([{ difficulty: 1 }, { difficulty: 4 }, { difficulty: 10 }, { difficulty: 11 }], 1000)); + expect(findClosest(stimuli, 1000)).toBe(3); }); + it('correctly selects a middle item if it equals exactly', () => { - expect(2).toBe(findClosest([{ difficulty: 1 }, { difficulty: 4 }, { difficulty: 10 }, { difficulty: 11 }], 10)); + expect(findClosest(stimuli, 10)).toBe(2); }); + it('correctly selects the one item closest to the target if less than', () => { - expect(1).toBe( - findClosest([{ difficulty: 1.1 }, { difficulty: 4.2 }, { difficulty: 10.3 }, { difficulty: 11.4 }], 5.1), - ); + const stimuliWithDecimal = [ + { difficulty: 1.1, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 4.2, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 10.3, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 11.4, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + ]; + expect(findClosest(stimuliWithDecimal, 5.1)).toBe(1); }); + it('correctly selects the one item closest 
to the target if greater than', () => { - expect(2).toBe( - findClosest([{ difficulty: 1.1 }, { difficulty: 4.2 }, { difficulty: 10.3 }, { difficulty: 11.4 }], 9.1), - ); + const stimuliWithDecimal = [ + { difficulty: 1.1, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 4.2, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 10.3, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 11.4, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + ]; + expect(findClosest(stimuliWithDecimal, 9.1)).toBe(2); }); }); diff --git a/src/cat.ts b/src/cat.ts new file mode 100644 index 0000000..5344f26 --- /dev/null +++ b/src/cat.ts @@ -0,0 +1,319 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ +import { minimize_Powell } from 'optimization-js'; +import { Stimulus, Zeta } from './type'; +import { itemResponseFunction, fisherInformation, normal, findClosest } from './utils'; +import { validateZetaParams, fillZetaDefaults } from './corpus'; +import seedrandom from 'seedrandom'; +import _clamp from 'lodash/clamp'; +import _cloneDeep from 'lodash/cloneDeep'; + +const abilityPrior = normal(); + +export interface CatInput { + method?: string; + itemSelect?: string; + nStartItems?: number; + startSelect?: string; + theta?: number; + minTheta?: number; + maxTheta?: number; + prior?: number[][]; + randomSeed?: string | null; +} + +export class Cat { + public method: string; + public itemSelect: string; + public minTheta: number; + public maxTheta: number; + public prior: number[][]; + private readonly _zetas: Zeta[]; + private readonly _resps: (0 | 1)[]; + private _theta: number; + private _seMeasurement: number; + public nStartItems: number; + public startSelect: string; + private readonly _rng: ReturnType; + + /** + * Create a Cat object. 
This expects a single object parameter with the following keys + * @param {{method: string, itemSelect: string, nStartItems: number, startSelect:string, theta: number, minTheta: number, maxTheta: number, prior: number[][]}=} destructuredParam + * method: ability estimator, e.g. MLE or EAP, default = 'MLE' + * itemSelect: the method of item selection, e.g. "MFI", "random", "closest", "fixed", default method = 'MFI' + * nStartItems: first n trials to keep non-adaptive selection, default = 0 + * startSelect: rule to select first n trials, e.g. "random", "middle", "fixed", default = 'middle' + * theta: initial theta estimate, default = 0 + * minTheta: lower bound of theta, default = -6 + * maxTheta: upper bound of theta, default = 6 + * prior: the prior distribution as [theta, probability] pairs, default = normal() + * randomSeed: seed for the random number generator, to make simulations reproducible; default = null (no explicit seed) + */ + + constructor({ + method = 'MLE', + itemSelect = 'MFI', + nStartItems = 0, + startSelect = 'middle', + theta = 0, + minTheta = -6, + maxTheta = 6, + prior = abilityPrior, + randomSeed = null, + }: CatInput = {}) { + this.method = Cat.validateMethod(method); + + this.itemSelect = Cat.validateItemSelect(itemSelect); + + this.startSelect = Cat.validateStartSelect(startSelect); + + this.minTheta = minTheta; + this.maxTheta = maxTheta; + this.prior = prior; + this._zetas = []; + this._resps = []; + this._theta = theta; + this._seMeasurement = Number.MAX_VALUE; + this.nStartItems = nStartItems; + this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed); + } + + public get theta() { + return this._theta; + } + + public get seMeasurement() { + return this._seMeasurement; + } + + /** + * Return the number of items that have been observed so far.
+ */ + public get nItems() { + return this._resps.length; + } + + public get resps() { + return this._resps; + } + + public get zetas() { + return this._zetas; + } + + private static validateMethod(method: string) { + const lowerMethod = method.toLowerCase(); + const validMethods: Array = ['mle', 'eap']; // TO DO: add staircase + if (!validMethods.includes(lowerMethod)) { + throw new Error('The abilityEstimator you provided is not in the list of valid methods'); + } + return lowerMethod; + } + + private static validateItemSelect(itemSelect: string) { + const lowerItemSelect = itemSelect.toLowerCase(); + const validItemSelect: Array = ['mfi', 'random', 'closest', 'fixed']; + if (!validItemSelect.includes(lowerItemSelect)) { + throw new Error('The itemSelector you provided is not in the list of valid methods'); + } + return lowerItemSelect; + } + + private static validateStartSelect(startSelect: string) { + const lowerStartSelect = startSelect.toLowerCase(); + const validStartSelect: Array = ['random', 'middle', 'fixed']; // TO DO: add staircase + if (!validStartSelect.includes(lowerStartSelect)) { + throw new Error('The startSelect you provided is not in the list of valid methods'); + } + return lowerStartSelect; + } + + /** + * use all recorded responses and item params, plus the newly supplied ones, to calculate the estimated ability based on a defined method + * @param zeta - item params of the newly answered item, or an array of them + * @param answer - response (0 or 1) to the newly answered item, or an array of responses matching zeta in length + * @param method - ability estimator to use ('MLE' or 'EAP', case-insensitive), defaults to this.method + */ + public updateAbilityEstimate(zeta: Zeta | Zeta[], answer: (0 | 1) | (0 | 1)[], method: string = this.method) { + method = Cat.validateMethod(method); + + zeta = Array.isArray(zeta) ? zeta : [zeta]; + answer = Array.isArray(answer) ?
answer : [answer]; + + zeta.forEach((z) => validateZetaParams(z, true)); + + if (zeta.length !== answer.length) { + throw new Error('Unmatched length between answers and item params'); + } + this._zetas.push(...zeta); + this._resps.push(...answer); + + if (method === 'eap') { + this._theta = this.estimateAbilityEAP(); + } else if (method === 'mle') { + this._theta = this.estimateAbilityMLE(); + } + this.calculateSE(); + } + + private estimateAbilityEAP() { + let num = 0; + let nf = 0; + this.prior.forEach(([theta, probability]) => { + const like = this.likelihood(theta); + num += theta * like * probability; + nf += like * probability; + }); + + return num / nf; + } + + private estimateAbilityMLE() { + const theta0 = [0]; + const solution = minimize_Powell(this.negLikelihood.bind(this), theta0); + const theta = solution.argument[0]; + return _clamp(theta, this.minTheta, this.maxTheta); + } + + private negLikelihood(thetaArray: Array) { + return -this.likelihood(thetaArray[0]); + } + + private likelihood(theta: number) { + return this._zetas.reduce((acc, zeta, i) => { + const irf = itemResponseFunction(theta, zeta); + return this._resps[i] === 1 ? 
acc + Math.log(irf) : acc + Math.log(1 - irf); + }, 1); + } + + /** + * calculate the standard error of ability estimation + */ + private calculateSE() { + const sum = this._zetas.reduce((previousValue, zeta) => previousValue + fisherInformation(this._theta, zeta), 0); + this._seMeasurement = 1 / Math.sqrt(sum); + } + + /** + * find the next available item from an input array of stimuli based on a selection method + * + * remainingStimuli is returned sorted by difficulty, except for the 'fixed' selector, which keeps the input order + * @param stimuli - an array of stimuli + * @param itemSelect - the item selection method, defaults to this.itemSelect + * @param deepCopy - when true (the default), stimuli is deep-cloned before processing + * @returns {nextStimulus: Stimulus, remainingStimuli: Stimulus[]} + */ + public findNextItem(stimuli: Stimulus[], itemSelect: string = this.itemSelect, deepCopy = true) { + let arr: Array; + let selector = Cat.validateItemSelect(itemSelect); + if (deepCopy) { + arr = _cloneDeep(stimuli); + } else { + arr = stimuli; + } + + arr = arr.map((stim) => fillZetaDefaults(stim, 'semantic')); + + if (this.nItems < this.nStartItems) { + selector = this.startSelect; + } + if (selector !== 'mfi' && selector !== 'fixed') { + // for mfi, we sort the arr by fisher information in the private function to select the best item, + // and then sort by difficulty to return the remainingStimuli + // for fixed, we want to keep the corpus order as input + arr.sort((a: Stimulus, b: Stimulus) => a.difficulty!
- b.difficulty!); + } + + if (selector === 'middle') { + // middle will only be used in startSelect + return this.selectorMiddle(arr); + } else if (selector === 'closest') { + return this.selectorClosest(arr); + } else if (selector === 'random') { + return this.selectorRandom(arr); + } else if (selector === 'fixed') { + return this.selectorFixed(arr); + } else { + return this.selectorMFI(arr); + } + } + + private selectorMFI(inputStimuli: Stimulus[]) { + const stimuli = inputStimuli.map((stim) => fillZetaDefaults(stim, 'semantic')); + const stimuliAddFisher = stimuli.map((element: Stimulus) => ({ + fisherInformation: fisherInformation(this._theta, fillZetaDefaults(element, 'symbolic')), + ...element, + })); + + stimuliAddFisher.sort((a, b) => b.fisherInformation - a.fisherInformation); + stimuliAddFisher.forEach((stimulus: Stimulus) => { + delete stimulus['fisherInformation']; + }); + return { + nextStimulus: stimuliAddFisher[0], + remainingStimuli: stimuliAddFisher.slice(1).sort((a: Stimulus, b: Stimulus) => a.difficulty! 
- b.difficulty!), + }; + } + + private selectorMiddle(arr: Stimulus[]) { + let index: number; + index = Math.floor(arr.length / 2); + + if (arr.length >= this.nStartItems) { + index += this.randomInteger(-Math.floor(this.nStartItems / 2), Math.floor(this.nStartItems / 2)); + } + + const nextItem = arr[index]; + arr.splice(index, 1); + return { + nextStimulus: nextItem, + remainingStimuli: arr, + }; + } + + private selectorClosest(arr: Stimulus[]) { + //findClosest requires arr is sorted by difficulty + const index = findClosest(arr, this._theta + 0.481); + const nextItem = arr[index]; + arr.splice(index, 1); + return { + nextStimulus: nextItem, + remainingStimuli: arr, + }; + } + + private selectorRandom(arr: Stimulus[]) { + const index = this.randomInteger(0, arr.length - 1); + const nextItem = arr.splice(index, 1)[0]; + return { + nextStimulus: nextItem, + remainingStimuli: arr, + }; + } + + /** + * Picks the next item in line from the given list of stimuli. + * It grabs the first item from the list, removes it, and then returns it along with the rest of the list. + * + * @param arr - The list of stimuli to choose from. + * @returns {Object} - An object with the next item and the updated list. + * @returns {Stimulus} return.nextStimulus - The item that was picked from the list. + * @returns {Stimulus[]} return.remainingStimuli - The list of what's left after picking the item. 
+ */ + private selectorFixed(arr: Stimulus[]) { + const nextItem = arr.shift(); + return { + nextStimulus: nextItem, + remainingStimuli: arr, + }; + } + + /** + * return a random integer between min and max + * @param min - The minimum of the random number range (include) + * @param max - The maximum of the random number range (include) + * @returns {number} - random integer within the range + */ + private randomInteger(min: number, max: number) { + return Math.floor(this._rng() * (max - min + 1)) + min; + } +} diff --git a/src/clowder.ts b/src/clowder.ts new file mode 100644 index 0000000..6117ff1 --- /dev/null +++ b/src/clowder.ts @@ -0,0 +1,393 @@ +import { Cat, CatInput } from './cat'; +import { CatMap, MultiZetaStimulus, Stimulus, Zeta, ZetaCatMap } from './type'; +import { filterItemsByCatParameterAvailability, checkNoDuplicateCatNames } from './corpus'; +import _cloneDeep from 'lodash/cloneDeep'; +import _differenceWith from 'lodash/differenceWith'; +import _isEqual from 'lodash/isEqual'; +import _mapValues from 'lodash/mapValues'; +import _omit from 'lodash/omit'; +import _unzip from 'lodash/unzip'; +import _zip from 'lodash/zip'; +import seedrandom from 'seedrandom'; +import { EarlyStopping } from './stopping'; + +export interface ClowderInput { + /** + * An object containing Cat configurations for each Cat instance. + * Keys correspond to Cat names, while values correspond to Cat configurations. + */ + cats: CatMap; + /** + * An object containing arrays of stimuli for each corpus. + */ + corpus: MultiZetaStimulus[]; + /** + * A random seed for reproducibility. If not provided, a random seed will be generated. + */ + randomSeed?: string | null; + /** + * An optional EarlyStopping instance to use for early stopping. + */ + earlyStopping?: EarlyStopping; +} + +/** + * The Clowder class is responsible for managing a collection of Cat instances + * along with a corpus of stimuli. 
It maintains a list of named Cat instances + * and a corpus where each item in the coprpus may have IRT parameters + * corresponding to each named Cat. Clowder provides methods for updating the + * ability estimates of each of its Cats, and selecting the next item to present + * to the participant. + */ +export class Clowder { + private _cats: CatMap; + private _corpus: MultiZetaStimulus[]; + private _remainingItems: MultiZetaStimulus[]; + private _seenItems: Stimulus[]; + private _earlyStopping?: EarlyStopping; + private readonly _rng: ReturnType; + private _stoppingReason: string | null; + + /** + * Create a Clowder object. + * + * @param {ClowderInput} input - An object containing arrays of Cat configurations and corpora. + * @param {CatMap} input.cats - An object containing Cat configurations for each Cat instance. + * @param {MultiZetaStimulus[]} input.corpus - An array of stimuli representing each corpus. + * + * @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name. + */ + constructor({ cats, corpus, randomSeed = null, earlyStopping }: ClowderInput) { + // TODO: Add some imput validation to both the cats and the corpus to make sure that "unvalidated" is not used as a cat name. + // If so, throw an error saying that "unvalidated" is a reserved name and may not be used. + // TODO: Also add a test of this behavior. + this._cats = { + ..._mapValues(cats, (catInput) => new Cat(catInput)), + unvalidated: new Cat({ itemSelect: 'random', randomSeed }), // Add 'unvalidated' cat + }; + this._seenItems = []; + checkNoDuplicateCatNames(corpus); + this._corpus = corpus; + this._remainingItems = _cloneDeep(corpus); + this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed); + this._earlyStopping = earlyStopping; + this._stoppingReason = null; + } + + /** + * Validate the provided Cat name against the existing Cat instances. + * Throw an error if the Cat name is not found. 
+ * + * @param {string} catName - The name of the Cat instance to validate. + * @param {boolean} allowUnvalidated - Whether to allow the reserved 'unvalidated' name. + * + * @throws {Error} - Throws an error if the provided Cat name is not found among the existing Cat instances. + */ + private _validateCatName(catName: string, allowUnvalidated = false): void { + const allowedCats = allowUnvalidated ? this._cats : this.cats; + if (!Object.prototype.hasOwnProperty.call(allowedCats, catName)) { + throw new Error(`Invalid Cat name. Expected one of ${Object.keys(allowedCats).join(', ')}. Received ${catName}.`); + } + } + + /** + * The named Cat instances that this Clowder manages. + */ + public get cats() { + return _omit(this._cats, ['unvalidated']); + } + + /** + * The corpus that was provided to this Clowder when it was created. + */ + public get corpus() { + return this._corpus; + } + + /** + * The subset of the input corpus that this Clowder has not yet "seen". + */ + public get remainingItems() { + return this._remainingItems; + } + + /** + * The subset of the input corpus that this Clowder has "seen" so far. + */ + public get seenItems() { + return this._seenItems; + } + + /** + * The theta estimates for each Cat instance. + */ + public get theta() { + return _mapValues(this.cats, (cat) => cat.theta); + } + + /** + * The standard error of measurement estimates for each Cat instance. + */ + public get seMeasurement() { + return _mapValues(this.cats, (cat) => cat.seMeasurement); + } + + /** + * The number of items presented to each Cat instance. + */ + public get nItems() { + return _mapValues(this.cats, (cat) => cat.nItems); + } + + /** + * The responses received by each Cat instance. + */ + public get resps() { + return _mapValues(this.cats, (cat) => cat.resps); + } + + /** + * The zeta (item parameters) received by each Cat instance. 
+ */ + public get zetas() { + return _mapValues(this.cats, (cat) => cat.zetas); + } + + /** + * The early stopping condition in the Clowder configuration. + */ + public get earlyStopping() { + return this._earlyStopping; + } + + /** + * The stopping reason in the Clowder configuration. + */ + public get stoppingReason() { + return this._stoppingReason; + } + + /** + * Updates the ability estimates for the specified Cat instances. + * + * @param {string[]} catNames - The names of the Cat instances to update. + * @param {Zeta | Zeta[]} zeta - The item parameter(s) (zeta) for the given stimuli. + * @param {(0 | 1) | (0 | 1)[]} answer - The corresponding answer(s) (0 or 1) for the given stimuli. + * @param {string} [method] - Optional method for updating ability estimates. If none is provided, it will use the default method for each Cat instance. + * + * @throws {Error} If any `catName` is not found among the existing Cat instances. + */ + public updateAbilityEstimates(catNames: string[], zeta: Zeta | Zeta[], answer: (0 | 1) | (0 | 1)[], method?: string) { + catNames.forEach((catName) => { + this._validateCatName(catName, false); + }); + for (const catName of catNames) { + this.cats[catName].updateAbilityEstimate(zeta, answer, method); + } + } + + /** + * Update the ability estimates for the specified `catsToUpdate` and select the next stimulus for the `catToSelect`. + * This function processes previous items and answers, updates internal state, and selects the next stimulus + * based on the remaining stimuli and `catToSelect`. + * + * @param {Object} input - The parameters for updating the Cat instance and selecting the next stimulus. + * @param {string} input.catToSelect - The Cat instance to use for selecting the next stimulus. + * @param {string | string[]} [input.catsToUpdate=[]] - A single Cat or array of Cats for which to update ability estimates. + * @param {Stimulus[]} [input.items=[]] - An array of previously presented stimuli. 
+ * @param {(0 | 1) | (0 | 1)[]} [input.answers=[]] - An array of answers (0 or 1) corresponding to `items`. + * @param {string} [input.method] - Optional method for updating ability estimates (if applicable). + * @param {string} [input.itemSelect] - Optional item selection method (if applicable). + * @param {boolean} [input.randomlySelectUnvalidated=false] - Optional flag indicating whether to randomly select an unvalidated item for `catToSelect`. + * + * @returns {Stimulus | undefined} - The next stimulus to present, or `undefined` if no further validated stimuli are available. + * + * @throws {Error} If `items` and `answers` lengths do not match. + * @throws {Error} If any `items` are not found in the Clowder's corpora (validated or unvalidated). + * + * The function operates in several steps: + * 1. Validate: + * a. Validates the `catToSelect` and `catsToUpdate`. + * b. Ensures `items` and `answers` arrays are properly formatted. + * 2. Update: + * a. Updates the internal list of seen items. + * b. Updates the ability estimates for the `catsToUpdate`. + * 3. Select: + * a. Selects the next item using `catToSelect`, considering only remaining items that are valid for that cat. + * b. If desired, randomly selects an unvalidated item for catToSelect. + */ + public updateCatAndGetNextItem({ + catToSelect, + catsToUpdate = [], + items = [], + answers = [], + method, + itemSelect, + randomlySelectUnvalidated = false, + returnUndefinedOnExhaustion = true, + }: { + catToSelect: string; + catsToUpdate?: string | string[]; + items?: MultiZetaStimulus | MultiZetaStimulus[]; + answers?: (0 | 1) | (0 | 1)[]; + method?: string; + itemSelect?: string; + randomlySelectUnvalidated?: boolean; + returnUndefinedOnExhaustion?: boolean; // New parameter type + }): Stimulus | undefined { + // +----------+ + // ----------| Update |----------| + // +----------+ + this._validateCatName(catToSelect, true); + catsToUpdate = Array.isArray(catsToUpdate) ? 
catsToUpdate : [catsToUpdate]; + catsToUpdate.forEach((cat) => { + this._validateCatName(cat, false); + }); + + // Convert items and answers to arrays + items = Array.isArray(items) ? items : [items]; + answers = Array.isArray(answers) ? answers : [answers]; + + // Ensure that the lengths of items and answers match + if (items.length !== answers.length) { + throw new Error('Previous items and answers must have the same length.'); + } + + // +----------+ + // ----------| Update |----------| + // +----------+ + + // Update the seenItems with the provided previous items + this._seenItems.push(...items); + + // Remove the provided previous items from the remainingItems + this._remainingItems = _differenceWith(this._remainingItems, items, _isEqual); + + // Create a new zip array of items and answers. This will be useful in + // filtering operations below. It ensures that items and their corresponding + // answers "stay together." + const itemsAndAnswers = _zip(items, answers) as [Stimulus, 0 | 1][]; + + // Update the ability estimate for all validated cats + for (const catName of catsToUpdate) { + const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => + // We are dealing with a single item in this function. This single item + // has an array of zeta parameters for a bunch of different Cats. We + // need to determine if `catName` is present in that list. So we first + // reduce the zetas to get all of the applicabe cat names. + // Now that we have the subset of items that can apply to this cat, + // retrieve only the item parameters that apply to this cat. 
+ stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.includes(catName)), + ); + + if (itemsAndAnswersForCat.length > 0) { + const zetasAndAnswersForCat = itemsAndAnswersForCat + .map(([stim, _answer]) => { + const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => + zeta.cats.includes(catName), + ); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + return [zetaForCat!.zeta, _answer]; // Optional chaining in case zetaForCat is undefined + }) + .filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values + + // Unzip the zetas and answers, making sure the zetas array contains only Zeta types + const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]]; + + // Now call updateAbilityEstimates for this cat + this.updateAbilityEstimates([catName], zetas, answers, method); + } + } + + if (this._earlyStopping) { + this._earlyStopping.update(this.cats, catToSelect); + if (this._earlyStopping.earlyStop) { + this._stoppingReason = 'Early stopping'; + return undefined; + } + } + + // Handle the 'unvalidated' cat selection + // +----------+ + // ----------| Select |----------| + // +----------+ + + // We inspect the remaining items and find ones that have zeta parameters for `catToSelect` + const { available, missing } = filterItemsByCatParameterAvailability(this._remainingItems, catToSelect); + + // Handle the 'unvalidated' cat selection + if (catToSelect === 'unvalidated') { + const unvalidatedRemainingItems = this._remainingItems.filter( + (stim) => !stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.length > 0), + ); + + if (unvalidatedRemainingItems.length === 0) { + // If returnUndefinedOnExhaustion is false, return an item from 'missing' + if (!returnUndefinedOnExhaustion && missing.length > 0) { + const randInt = Math.floor(this._rng() * missing.length); + return missing[randInt]; + } + this._stoppingReason = 'No unvalidated items remaining'; + return undefined; + } else { + const randInt = 
Math.floor(this._rng() * unvalidatedRemainingItems.length); + return unvalidatedRemainingItems[randInt]; + } + } + + // The cat expects an array of Stimulus objects, with the zeta parameters + // spread at the top-level of each Stimulus object. So we need to convert + // the MultiZetaStimulus array to an array of Stimulus objects. + const availableCatInput = available.map((item) => { + const zetasForCat = item.zetas.find((zeta) => zeta.cats.includes(catToSelect)); + return { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + ...zetasForCat!.zeta, + ...item, + }; + }); + + // Use the catForSelect to determine the next stimulus + const cat = this.cats[catToSelect]; + const { nextStimulus } = cat.findNextItem(availableCatInput, itemSelect); + const nextStimulusWithoutZeta = _omit(nextStimulus, [ + 'a', + 'b', + 'c', + 'd', + 'discrimination', + 'difficulty', + 'guessing', + 'slipping', + ]); + // Again `nextStimulus` will be a Stimulus object, or `undefined` if no further validated stimuli are available. + // We need to convert the Stimulus object back to a MultiZetaStimulus object to return to the user. + const returnStimulus: MultiZetaStimulus | undefined = available.find((stim) => + _isEqual( + _omit(stim, ['a', 'b', 'c', 'd', 'discrimination', 'difficulty', 'guessing', 'slipping']), + nextStimulusWithoutZeta, + ), + ); + + // Determine behavior based on returnUndefinedOnExhaustion + if (available.length === 0) { + // If returnUndefinedOnExhaustion is true and no validated items remain for the specified catToSelect, return undefined. + if (returnUndefinedOnExhaustion) { + this._stoppingReason = 'No validated items remaining for specified catToSelect'; + return undefined; // Return undefined if no validated items remain + } else { + // If returnUndefinedOnExhaustion is false, proceed with the fallback mechanism to select an item from other available categories. 
+ return missing[Math.floor(this._rng() * missing.length)]; + } + } else if (missing.length === 0 || !randomlySelectUnvalidated) { + return returnStimulus; // Return validated item if available + } else { + // Randomly decide whether to return a validated or unvalidated item + const random = Math.random(); + const numRemaining = { available: available.length, missing: missing.length }; + return random < numRemaining.missing / (numRemaining.available + numRemaining.missing) + ? missing[Math.floor(this._rng() * missing.length)] + : returnStimulus; + } + } +} diff --git a/src/corpus.ts b/src/corpus.ts new file mode 100644 index 0000000..94102f5 --- /dev/null +++ b/src/corpus.ts @@ -0,0 +1,306 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ +import { MultiZetaStimulus, Stimulus, Zeta } from './type'; +import _flatten from 'lodash/flatten'; +import _invert from 'lodash/invert'; +import _isEmpty from 'lodash/isEmpty'; +import _mapKeys from 'lodash/mapKeys'; +import _union from 'lodash/union'; +import _uniq from 'lodash/uniq'; +import _omit from 'lodash/omit'; + +/** + * A constant map from the symbolic item parameter names to their semantic + * counterparts. + */ +export const ZETA_KEY_MAP = { + a: 'discrimination', + b: 'difficulty', + c: 'guessing', + d: 'slipping', +}; + +/** + * Return default item parameters (i.e., zeta) + * + * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. + * @returns {Zeta} the default zeta object in the specified format. + */ +export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { + const defaultZeta: Zeta = { + a: 1, + b: 0, + c: 0, + d: 1, + }; + + return convertZeta(defaultZeta, desiredFormat); +}; + +/** + * Validates the item (a.k.a. zeta) parameters, prohibiting redundant keys and + * optionally requiring all parameters. + * + * @param {Zeta} zeta - The zeta parameters to validate. 
+ * @param {boolean} requireAll - If `true`, ensures that all required keys are present. Default is `false`. + * + * @throws {Error} Will throw an error if any of the validation rules are violated. + */ +export const validateZetaParams = (zeta: Zeta, requireAll = false): void => { + if (zeta.a !== undefined && zeta.discrimination !== undefined) { + throw new Error('This item has both an `a` key and `discrimination` key. Please provide only one.'); + } + + if (zeta.b !== undefined && zeta.difficulty !== undefined) { + throw new Error('This item has both a `b` key and `difficulty` key. Please provide only one.'); + } + + if (zeta.c !== undefined && zeta.guessing !== undefined) { + throw new Error('This item has both a `c` key and `guessing` key. Please provide only one.'); + } + + if (zeta.d !== undefined && zeta.slipping !== undefined) { + throw new Error('This item has both a `d` key and `slipping` key. Please provide only one.'); + } + + if (requireAll) { + if (zeta.a === undefined && zeta.discrimination === undefined) { + throw new Error('This item is missing the key `a` or `discrimination`.'); + } + + if (zeta.b === undefined && zeta.difficulty === undefined) { + throw new Error('This item is missing the key `b` or `difficulty`.'); + } + + if (zeta.c === undefined && zeta.guessing === undefined) { + throw new Error('This item is missing the key `c` or `guessing`.'); + } + + if (zeta.d === undefined && zeta.slipping === undefined) { + throw new Error('This item is missing the key `d` or `slipping`.'); + } + } +}; + +/** + * Fills in default zeta parameters for any missing keys in the provided zeta object. + * + * @remarks + * This function merges the provided zeta object with the default zeta object, converting + * the keys to the desired format if specified. If no desired format is provided, the + * keys will remain in their original format. + * + * @param {Zeta} zeta - The zeta parameters to fill in defaults for. 
+ * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. Default is 'symbolic'. + * + * @returns A new zeta object with default values filled in for any missing keys, + * and converted to the desired format if specified. + */ +export const fillZetaDefaults = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { + return { + ...defaultZeta(desiredFormat), + ...convertZeta(zeta, desiredFormat), + }; +}; + +/** + * Converts zeta parameters between symbolic and semantic formats. + * + * @remarks + * This function takes a zeta object and a desired format as input. It converts + * the keys of the zeta object from their current format to the desired format. + * If the desired format is 'symbolic', the function maps the keys to their + * symbolic counterparts using the `ZETA_KEY_MAP`. If the desired format is + * 'semantic', the function maps the keys to their semantic counterparts using + * the inverse of `ZETA_KEY_MAP`. + * + * @param {Zeta} zeta - The zeta parameters to convert. + * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. Must be either 'symbolic' or 'semantic'. + * + * @throws {Error} - Will throw an error if the desired format is not 'symbolic' or 'semantic'. + * + * @returns {Zeta} A new zeta object with keys converted to the desired format. + */ +export const convertZeta = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic'): Zeta => { + if (!['symbolic', 'semantic'].includes(desiredFormat)) { + throw new Error(`Invalid desired format. Expected 'symbolic' or'semantic'. 
Received ${desiredFormat} instead.`); + } + + return _mapKeys(zeta, (value, key) => { + if (desiredFormat === 'symbolic') { + const inverseMap = _invert(ZETA_KEY_MAP); + if (key in inverseMap) { + return inverseMap[key]; + } else { + return key; + } + } else { + if (key in ZETA_KEY_MAP) { + return ZETA_KEY_MAP[key as keyof typeof ZETA_KEY_MAP]; + } else { + return key; + } + } + }); +}; + +/** + * Validates a corpus of multi-zeta stimuli to ensure that no cat names are + * duplicated. + * + * @remarks + * This function takes an array of `MultiZetaStimulus` objects, where each + * object represents an item containing item parameters (zetas) associated with + * different CAT models. The function checks for any duplicate cat names across + * each item's array of zeta values. It throws an error if any are found. + * + * @param {MultiZetaStimulus[]} corpus - An array of `MultiZetaStimulus` objects representing the corpora to validate. + * + * @throws {Error} - Throws an error if any duplicate cat names are found across the corpora. + */ +export const checkNoDuplicateCatNames = (corpus: MultiZetaStimulus[]): void => { + const zetaCatMapsArray = corpus.map((item) => item.zetas); + for (const zetaCatMaps of zetaCatMapsArray) { + const cats = zetaCatMaps.map(({ cats }) => cats); + + // Check to see if there are any duplicate names by comparing the union + // (which removed duplicates) to the flattened array. + const union = _union(...cats); + const flattened = _flatten(cats); + + if (union.length !== flattened.length) { + // If there are duplicates, remove the first occurence of each cat name in + // the union array from the flattened array. The remaining items in the + // flattened array should contain the duplicated cat names. 
+ for (const cat of union) { + const idx = flattened.findIndex((c) => c === cat); + if (idx >= 0) { + flattened.splice(idx, 1); + } + } + + throw new Error(`The cat names ${_uniq(flattened).join(', ')} are present in multiple corpora.`); + } + } +}; + +/** + * Filters a list of multi-zeta stimuli based on the availability of model parameters for a specific CAT. + * + * This function takes an array of `MultiZetaStimulus` objects and a `catName` as input. It then filters + * the items based on whether the specified CAT model parameter is present in the item's zeta values. + * The function returns an object containing two arrays: `available` and `missing`. The `available` array + * contains items where the specified CAT model parameter is present, while the `missing` array contains + * items where the parameter is not present. + * + * @param {MultiZetaStimulus[]} items - An array of `MultiZetaStimulus` objects representing the stimuli to filter. + * @param {string} catName - The name of the CAT model parameter to check for. + * + * @returns An object with two arrays: `available` and `missing`. 
+ * + * @example + * ```typescript + * const items: MultiZetaStimulus[] = [ + * { + * stimulus: 'Item 1', + * zetas: [ + * { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + * { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + * ], + * }, + * { + * stimulus: 'Item 2', + * zetas: [ + * { cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, + * ], + * }, + * ]; + * + * const result = filterItemsByCatParameterAvailability(items, 'Model A'); + * console.log(result.available); + * // Output: [ + * // { + * // stimulus: 'Item 1', + * // zetas: [ + * // { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + * // { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + * // ], + * // }, + * // ] + * console.log(result.missing); + * // Output: [ + * // { + * // stimulus: 'Item 2', + * // zetas: [ + * // { cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, + * // ], + * // }, + * // ] + * ``` + */ +export const filterItemsByCatParameterAvailability = (items: MultiZetaStimulus[], catName: string) => { + const paramsExist = items.filter((item) => item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); + const paramsMissing = items.filter((item) => !item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); + + return { + available: paramsExist, + missing: paramsMissing, + }; +}; + +/** + * Converts an array of Stimulus objects into an array of MultiZetaStimulus objects. + * The user specifies cat names and a delimiter to identify and group parameters. + * + * @param {Stimulus[]} items - An array of stimuli, where each stimulus contains parameters + * for different CAT instances. + * @param {string[]} catNames - A list of CAT names to be mapped to their corresponding zeta values. + * @param {string} delimiter - A delimiter used to separate CAT instance names from the parameter keys in the stimulus object. 
+ * @param {'symbolic' | 'semantic'} itemParameterFormat - Defines the format to convert zeta values ('symbolic' or 'semantic'). + * @returns {MultiZetaStimulus[]} - An array of MultiZetaStimulus objects, each containing + * the cleaned stimulus and associated zeta values for each CAT instance. + * + * This function iterates through each stimulus, extracts parameters relevant to the specified + * CAT instances, converts them to the desired format, and returns a cleaned structure of stimuli + * with the associated zeta values. + */ +export const prepareClowderCorpus = ( + items: Stimulus[], + catNames: string[], + delimiter: '.' | string, + itemParameterFormat: 'symbolic' | 'semantic' = 'symbolic', +): MultiZetaStimulus[] => { + return items.map((item) => { + const zetas = catNames + .map((cat) => { + const zeta: Zeta = {}; + + // Extract parameters that match the category + Object.keys(item).forEach((key) => { + if (key.startsWith(cat + delimiter)) { + const paramKey = key.split(delimiter)[1]; + zeta[paramKey as keyof Zeta] = item[key]; + } + }); + + return { + cats: [cat], + zeta: convertZeta(zeta, itemParameterFormat), + }; + }) + .filter((zeta) => { + // Check if zeta has no `NA` values and is not empty + return !_isEmpty(zeta.zeta) && Object.values(zeta.zeta).every((value) => value !== 'NA'); + }); + + // Create the MultiZetaStimulus structure without the category keys + const cleanItem = _omit( + item, + Object.keys(item).filter((key) => catNames.some((cat) => key.startsWith(cat + delimiter))), + ); + + return { + ...cleanItem, + zetas, + }; + }); +}; diff --git a/src/index.ts b/src/index.ts index 83424c1..2571bd7 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,321 +1,9 @@ -import { minimize_Powell } from 'optimization-js'; -import { cloneDeep } from 'lodash'; -import { Stimulus, Zeta } from './type'; -import { itemResponseFunction, fisherInformation, normal, findClosest } from './utils'; -import seedrandom from 'seedrandom'; - -export const abilityPrior = 
normal(); - -export interface CatInput { - method?: string; - itemSelect?: string; - nStartItems?: number; - startSelect?: string; - theta?: number; - minTheta?: number; - maxTheta?: number; - prior?: number[][]; - randomSeed?: string | null; -} - -export class Cat { - public method: string; - public itemSelect: string; - public minTheta: number; - public maxTheta: number; - public prior: number[][]; - private readonly _zetas: Zeta[]; - private readonly _resps: (0 | 1)[]; - private _nItems: number; - private _theta: number; - private _seMeasurement: number; - public nStartItems: number; - public startSelect: string; - private readonly _rng: ReturnType; - - /** - * Create a Cat object. This expects an single object parameter with the following keys - * @param {{method: string, itemSelect: string, nStartItems: number, startSelect:string, theta: number, minTheta: number, maxTheta: number, prior: number[][]}=} destructuredParam - * method: ability estimator, e.g. MLE or EAP, default = 'MLE' - * itemSelect: the method of item selection, e.g. 
"MFI", "random", "closest", default method = 'MFI' - * nStartItems: first n trials to keep non-adaptive selection - * startSelect: rule to select first n trials - * theta: initial theta estimate - * minTheta: lower bound of theta - * maxTheta: higher bound of theta - * prior: the prior distribution - * randomSeed: set a random seed to trace the simulation - */ - - constructor({ - method = 'MLE', - itemSelect = 'MFI', - nStartItems = 0, - startSelect = 'middle', - theta = 0, - minTheta = -6, - maxTheta = 6, - prior = abilityPrior, - randomSeed = null, - }: CatInput = {}) { - this.method = Cat.validateMethod(method); - - this.itemSelect = Cat.validateItemSelect(itemSelect); - - this.startSelect = Cat.validateStartSelect(startSelect); - - this.minTheta = minTheta; - this.maxTheta = maxTheta; - this.prior = prior; - this._zetas = []; - this._resps = []; - this._theta = theta; - this._nItems = 0; - this._seMeasurement = Number.MAX_VALUE; - this.nStartItems = nStartItems; - this._rng = randomSeed === null ? 
seedrandom() : seedrandom(randomSeed); - } - - public get theta() { - return this._theta; - } - - public get seMeasurement() { - return this._seMeasurement; - } - - public get nItems() { - return this._resps.length; - } - - public get resps() { - return this._resps; - } - - public get zetas() { - return this._zetas; - } - - private static validateMethod(method: string) { - const lowerMethod = method.toLowerCase(); - const validMethods: Array = ['mle', 'eap']; // TO DO: add staircase - if (!validMethods.includes(lowerMethod)) { - throw new Error('The abilityEstimator you provided is not in the list of valid methods'); - } - return lowerMethod; - } - - private static validateItemSelect(itemSelect: string) { - const lowerItemSelect = itemSelect.toLowerCase(); - const validItemSelect: Array = ['mfi', 'random', 'closest', 'fixed']; - if (!validItemSelect.includes(lowerItemSelect)) { - throw new Error('The itemSelector you provided is not in the list of valid methods'); - } - return lowerItemSelect; - } - - private static validateStartSelect(startSelect: string) { - const lowerStartSelect = startSelect.toLowerCase(); - const validStartSelect: Array = ['random', 'middle', 'fixed']; // TO DO: add staircase - if (!validStartSelect.includes(lowerStartSelect)) { - throw new Error('The startSelect you provided is not in the list of valid methods'); - } - return lowerStartSelect; - } - - /** - * use previous response patterns and item params to calculate the estimate ability based on a defined method - * @param zeta - last item param - * @param answer - last response pattern - * @param method - */ - public updateAbilityEstimate(zeta: Zeta | Zeta[], answer: (0 | 1) | (0 | 1)[], method: string = this.method) { - method = Cat.validateMethod(method); - - zeta = Array.isArray(zeta) ? zeta : [zeta]; - answer = Array.isArray(answer) ? 
answer : [answer]; - - if (zeta.length !== answer.length) { - throw new Error('Unmatched length between answers and item params'); - } - this._zetas.push(...zeta); - this._resps.push(...answer); - - if (method === 'eap') { - this._theta = this.estimateAbilityEAP(); - } else if (method === 'mle') { - this._theta = this.estimateAbilityMLE(); - } - this.calculateSE(); - } - - private estimateAbilityEAP() { - let num = 0; - let nf = 0; - this.prior.forEach(([theta, probability]) => { - const like = this.likelihood(theta); - num += theta * like * probability; - nf += like * probability; - }); - - return num / nf; - } - - private estimateAbilityMLE() { - const theta0 = [0]; - const solution = minimize_Powell(this.negLikelihood.bind(this), theta0); - let theta = solution.argument[0]; - if (theta > this.maxTheta) { - theta = this.maxTheta; - } else if (theta < this.minTheta) { - theta = this.minTheta; - } - return theta; - } - - private negLikelihood(thetaArray: Array) { - return -this.likelihood(thetaArray[0]); - } - - private likelihood(theta: number) { - return this._zetas.reduce((acc, zeta, i) => { - const irf = itemResponseFunction(theta, zeta); - return this._resps[i] === 1 ? 
acc + Math.log(irf) : acc + Math.log(1 - irf); - }, 1); - } - - /** - * calculate the standard error of ability estimation - */ - private calculateSE() { - const sum = this._zetas.reduce((previousValue, zeta) => previousValue + fisherInformation(this._theta, zeta), 0); - this._seMeasurement = 1 / Math.sqrt(sum); - } - - /** - * find the next available item from an input array of stimuli based on a selection method - * - * remainingStimuli is sorted by fisher information to reduce the computation complexity for future item selection - * @param stimuli - an array of stimulus - * @param itemSelect - the item selection method - * @param deepCopy - default deepCopy = true - * @returns {nextStimulus: Stimulus, - remainingStimuli: Array} - */ - public findNextItem(stimuli: Stimulus[], itemSelect: string = this.itemSelect, deepCopy = true) { - let arr: Array; - let selector = Cat.validateItemSelect(itemSelect); - if (deepCopy) { - arr = cloneDeep(stimuli); - } else { - arr = stimuli; - } - if (this.nItems < this.nStartItems) { - selector = this.startSelect; - } - if (selector !== 'mfi' && selector !== 'fixed') { - // for mfi, we sort the arr by fisher information in the private function to select the best item, - // and then sort by difficulty to return the remainingStimuli - // for fixed, we want to keep the corpus order as input - arr.sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty); - } - - if (selector === 'middle') { - // middle will only be used in startSelect - return this.selectorMiddle(arr); - } else if (selector === 'closest') { - return this.selectorClosest(arr); - } else if (selector === 'random') { - return this.selectorRandom(arr); - } else if (selector === 'fixed') { - return this.selectorFixed(arr); - } else { - return this.selectorMFI(arr); - } - } - - private selectorMFI(arr: Stimulus[]) { - const stimuliAddFisher = arr.map((element: Stimulus) => ({ - fisherInformation: fisherInformation(this._theta, { - a: element.a || 1, - b: 
element.difficulty || 0, - c: element.c || 0, - d: element.d || 1, - }), - ...element, - })); - - stimuliAddFisher.sort((a, b) => b.fisherInformation - a.fisherInformation); - stimuliAddFisher.forEach((stimulus: Stimulus) => { - delete stimulus['fisherInformation']; - }); - return { - nextStimulus: stimuliAddFisher[0], - remainingStimuli: stimuliAddFisher.slice(1).sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty), - }; - } - - private selectorMiddle(arr: Stimulus[]) { - let index: number; - if (arr.length < this.nStartItems) { - index = Math.floor(arr.length / 2); - } else { - index = - Math.floor(arr.length / 2) + - this.randomInteger(-Math.floor(this.nStartItems / 2), Math.floor(this.nStartItems / 2)); - } - const nextItem = arr[index]; - arr.splice(index, 1); - return { - nextStimulus: nextItem, - remainingStimuli: arr, - }; - } - - private selectorClosest(arr: Stimulus[]) { - //findClosest requires arr is sorted by difficulty - const index = findClosest(arr, this._theta + 0.481); - const nextItem = arr[index]; - arr.splice(index, 1); - return { - nextStimulus: nextItem, - remainingStimuli: arr, - }; - } - - private selectorRandom(arr: Stimulus[]) { - const index = Math.floor(this._rng() * arr.length); - const nextItem = arr.splice(index, 1)[0]; - return { - nextStimulus: nextItem, - remainingStimuli: arr, - }; - } - - /** - * Picks the next item in line from the given list of stimuli. - * It grabs the first item from the list, removes it, and then returns it along with the rest of the list. - * - * @param arr - The list of stimuli to choose from. - * @returns {Object} - An object with the next item and the updated list. - * @returns {Stimulus} return.nextStimulus - The item that was picked from the list. - * @returns {Stimulus[]} return.remainingStimuli - The list of what's left after picking the item. - */ - private selectorFixed(arr: Stimulus[]) { - const nextItem = arr.shift() ?? 
null; - return { - nextStimulus: nextItem, - remainingStimuli: arr, - }; - } - - /** - * return a random integer between min and max - * @param min - The minimum of the random number range (include) - * @param max - The maximum of the random number range (include) - * @returns {number} - random integer within the range - */ - private randomInteger(min: number, max: number) { - return Math.floor(this._rng() * (max - min + 1)) + min; - } -} +export { Cat, CatInput } from './cat'; +export { Clowder, ClowderInput } from './clowder'; +export { prepareClowderCorpus } from './corpus'; +export { + EarlyStopping, + StopAfterNItems, + StopOnSEMeasurementPlateau, + StopIfSEMeasurementBelowThreshold, +} from './stopping'; diff --git a/src/stopping.ts b/src/stopping.ts new file mode 100644 index 0000000..937a01f --- /dev/null +++ b/src/stopping.ts @@ -0,0 +1,258 @@ +import { Cat } from './cat'; +import { CatMap } from './type'; +import _uniq from 'lodash/uniq'; + +/** + * Interface for input parameters to EarlyStopping classes. 
+ */ +export interface EarlyStoppingInput { + /** The logical operation to use to evaluate multiple stopping conditions */ + logicalOperation?: 'and' | 'or' | 'only' | 'AND' | 'OR' | 'ONLY'; +} + +export interface StopAfterNItemsInput extends EarlyStoppingInput { + /** Number of items to require before stopping */ + requiredItems: CatMap; +} + +export interface StopOnSEMeasurementPlateauInput extends EarlyStoppingInput { + /** Number of items to wait for before triggering early stopping */ + patience: CatMap; + /** Tolerance for standard error of measurement drop */ + tolerance?: CatMap; +} + +export interface StopIfSEMeasurementBelowThresholdInput extends EarlyStoppingInput { + /** Stop if the standard error of measurement drops below this level */ + seMeasurementThreshold: CatMap; + /** Number of items to wait for before triggering early stopping */ + patience?: CatMap; + /** Tolerance for standard error of measurement drop */ + tolerance?: CatMap; +} + +/** + * Abstract class for early stopping strategies. + */ +export abstract class EarlyStopping { + protected _earlyStop: boolean; + protected _nItems: CatMap; + protected _seMeasurements: CatMap; + protected _logicalOperation: 'and' | 'or' | 'only'; + + constructor({ logicalOperation = 'or' }: EarlyStoppingInput) { + this._seMeasurements = {}; + this._nItems = {}; + this._earlyStop = false; + + if (!['and', 'or', 'only'].includes(logicalOperation.toLowerCase())) { + throw new Error(`Invalid logical operation. Expected "and", "or", or "only". 
Received "${logicalOperation}"`); + } + this._logicalOperation = logicalOperation.toLowerCase() as 'and' | 'or' | 'only'; + } + + public abstract get evaluationCats(): string[]; + + public get earlyStop() { + return this._earlyStop; + } + + public get nItems() { + return this._nItems; + } + + public get seMeasurements() { + return this._seMeasurements; + } + + public get logicalOperation() { + return this._logicalOperation; + } + + /** + * Update the internal state of the early stopping strategy based on the provided cats. + * @param {CatMap}cats - A map of cats to update. + */ + protected _updateCats(cats: CatMap) { + for (const catName in cats) { + const cat = cats[catName]; + const nItems = cat.nItems; + const seMeasurement = cat.seMeasurement; + + if (nItems > (this._nItems[catName] ?? 0)) { + this._nItems[catName] = nItems; + this._seMeasurements[catName] = [...(this._seMeasurements[catName] ?? []), seMeasurement]; + } + } + } + + /** + * Abstract method to be implemented by subclasses to evaluate a single stopping condition. + * @param {string} catToEvaluate - The name of the cat to evaluate for early stopping. + */ + protected abstract _evaluateStoppingCondition(catToEvaluate: string): boolean; + + /** + * Abstract method to be implemented by subclasses to update the early stopping strategy. + * @param {CatMap} cats - A map of cats to update. 
+ */ + public update(cats: CatMap, catToSelect?: string): void { + this._updateCats(cats); // This updates internal state with current cat data + + // Collect the stopping conditions for all cats + const conditions: boolean[] = this.evaluationCats.map((catName) => this._evaluateStoppingCondition(catName)); + + // Evaluate the stopping condition based on the logical operation + if (this._logicalOperation === 'and') { + this._earlyStop = conditions.every(Boolean); // All conditions must be true for 'and' + } else if (this._logicalOperation === 'or') { + this._earlyStop = conditions.some(Boolean); // Any condition can be true for 'or' + } else if (this._logicalOperation === 'only') { + if (catToSelect === undefined) { + throw new Error('Must provide a cat to select for "only" stopping condition'); + } + + // Evaluate the stopping condition for the selected cat + if (this.evaluationCats.includes(catToSelect)) { + this._earlyStop = this._evaluateStoppingCondition(catToSelect); + } else { + this._earlyStop = false; // Default to false if the selected cat is not in evaluationCats + } + } + } +} + +/** + * Class implementing early stopping based on a plateau in standard error of measurement. + */ +export class StopOnSEMeasurementPlateau extends EarlyStopping { + protected _patience: CatMap; + protected _tolerance: CatMap; + + constructor(input: StopOnSEMeasurementPlateauInput) { + super(input); + this._patience = input.patience; + this._tolerance = input.tolerance ?? {}; + } + + public get evaluationCats() { + return _uniq([...Object.keys(this._patience), ...Object.keys(this._tolerance)]); + } + + public get patience() { + return this._patience; + } + + public get tolerance() { + return this._tolerance; + } + + protected _evaluateStoppingCondition(catToEvaluate: string) { + const seMeasurements = this._seMeasurements[catToEvaluate]; + + // Use MAX_SAFE_INTEGER and MAX_VALUE to prevent early stopping if the `catToEvaluate` is missing from the cats map. 
+ const patience = this._patience[catToEvaluate]; + const tolerance = this._tolerance[catToEvaluate]; + + let earlyStop = false; + + if (seMeasurements?.length >= patience) { + const mean = seMeasurements.slice(-patience).reduce((sum, se) => sum + se, 0) / patience; + const withinTolerance = seMeasurements.slice(-patience).every((se) => Math.abs(se - mean) <= tolerance); + + if (withinTolerance) { + earlyStop = true; + } + } + + return earlyStop; + } +} + +/** + * Class implementing early stopping after a certain number of items. + */ +export class StopAfterNItems extends EarlyStopping { + protected _requiredItems: CatMap; + + constructor(input: StopAfterNItemsInput) { + super(input); + this._requiredItems = input.requiredItems; + } + + public get requiredItems() { + return this._requiredItems; + } + + public get evaluationCats() { + return Object.keys(this._requiredItems); + } + + protected _evaluateStoppingCondition(catToEvaluate: string) { + const requiredItems = this._requiredItems[catToEvaluate]; + const nItems = this._nItems[catToEvaluate]; + + let earlyStop = false; + + if (nItems >= requiredItems) { + earlyStop = true; + } + + return earlyStop; + } +} + +/** + * Class implementing early stopping if the standard error of measurement drops below a certain threshold. + */ +export class StopIfSEMeasurementBelowThreshold extends EarlyStopping { + protected _patience: CatMap; + protected _tolerance: CatMap; + protected _seMeasurementThreshold: CatMap; + + constructor(input: StopIfSEMeasurementBelowThresholdInput) { + super(input); + this._seMeasurementThreshold = input.seMeasurementThreshold; + this._patience = input.patience ?? {}; + this._tolerance = input.tolerance ?? 
{}; + } + + public get patience() { + return this._patience; + } + + public get tolerance() { + return this._tolerance; + } + + public get seMeasurementThreshold() { + return this._seMeasurementThreshold; + } + + public get evaluationCats() { + return _uniq([ + ...Object.keys(this._patience), + ...Object.keys(this._tolerance), + ...Object.keys(this._seMeasurementThreshold), + ]); + } + + protected _evaluateStoppingCondition(catToEvaluate: string) { + const seMeasurements = this._seMeasurements[catToEvaluate] ?? []; + const seThreshold = this._seMeasurementThreshold[catToEvaluate] ?? 0; + const patience = this._patience[catToEvaluate] ?? 1; + const tolerance = this._tolerance[catToEvaluate] ?? 0; + + let earlyStop = false; + + if (seMeasurements.length >= patience) { + const withinTolerance = seMeasurements.slice(-patience).every((se) => se - seThreshold <= tolerance); + + if (withinTolerance) { + earlyStop = true; + } + } + + return earlyStop; + } +} diff --git a/src/type.ts b/src/type.ts index 36b4793..739c65d 100644 --- a/src/type.ts +++ b/src/type.ts @@ -1,7 +1,40 @@ -export type Zeta = { a: number; b: number; c: number; d: number }; +export type ZetaSymbolic = { + // Symbolic parameter names + a: number; // Discrimination (slope of the curve) + b: number; // Difficulty (location of the curve) + c: number; // Guessing (lower asymptote) + d: number; // Slipping (upper asymptote) +}; -export interface Stimulus { - difficulty: number; +export interface Zeta { + // Symbolic parameter names + a?: number; // Discrimination (slope of the curve) + b?: number; // Difficulty (location of the curve) + c?: number; // Guessing (lower asymptote) + d?: number; // Slipping (upper asymptote) + // Semantic parameter names + discrimination?: number; + difficulty?: number; + guessing?: number; + slipping?: number; +} + +export interface Stimulus extends Zeta { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + [key: string]: any; +} + +export type ZetaCatMap = { + 
cats: string[]; + zeta: Zeta; +}; + +export interface MultiZetaStimulus { + zetas: ZetaCatMap[]; // eslint-disable-next-line @typescript-eslint/no-explicit-any [key: string]: any; } + +export type CatMap = { + [name: string]: T; +}; diff --git a/src/utils.ts b/src/utils.ts index f7a8d8c..b2c276f 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,35 +1,43 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ import bs from 'binary-search'; -import { Stimulus, Zeta } from './type'; +import { Stimulus, Zeta, ZetaSymbolic } from './type'; +import { fillZetaDefaults } from './corpus'; /** - * calculates the probability that someone with a given ability level theta will answer correctly an item. Uses the 4 parameters logistic model - * @param theta - ability estimate - * @param zeta - item params + * Calculates the probability that someone with a given ability level theta will + * answer correctly an item. Uses the 4 parameters logistic model + * + * @param {number} theta - ability estimate + * @param {Zeta} zeta - item params * @returns {number} the probability */ export const itemResponseFunction = (theta: number, zeta: Zeta) => { - return zeta.c + (zeta.d - zeta.c) / (1 + Math.exp(-zeta.a * (theta - zeta.b))); + const _zeta = fillZetaDefaults(zeta, 'symbolic') as ZetaSymbolic; + return _zeta.c + (_zeta.d - _zeta.c) / (1 + Math.exp(-_zeta.a * (theta - _zeta.b))); }; /** - * a 3PL Fisher information function - * @param theta - ability estimate - * @param zeta - item params + * A 3PL Fisher information function + * + * @param {number} theta - ability estimate + * @param {Zeta} zeta - item params * @returns {number} - the expected value of the observed information */ export const fisherInformation = (theta: number, zeta: Zeta) => { - const p = itemResponseFunction(theta, zeta); + const _zeta = fillZetaDefaults(zeta, 'symbolic') as ZetaSymbolic; + const p = itemResponseFunction(theta, _zeta); const q = 1 - p; - return Math.pow(zeta.a, 2) * (q / p) * (Math.pow(p 
- zeta.c, 2) / Math.pow(1 - zeta.c, 2)); + return Math.pow(_zeta.a, 2) * (q / p) * (Math.pow(p - _zeta.c, 2) / Math.pow(1 - _zeta.c, 2)); }; /** - * return a Gaussian distribution within a given range - * @param mean - * @param stdDev - * @param min - * @param max - * @param stepSize - the quantization (step size) of the internal table, default = 0.1 + * Return a Gaussian distribution within a given range + * + * @param {number} mean + * @param {number} stdDev + * @param {number} min + * @param {number} max + * @param {number} stepSize - the quantization (step size) of the internal table, default = 0.1 * @returns {Array<[number, number]>} - a normal distribution */ export const normal = (mean = 0, stdDev = 1, min = -4, max = 4, stepSize = 0.1) => { @@ -45,27 +53,28 @@ export const normal = (mean = 0, stdDev = 1, min = -4, max = 4, stepSize = 0.1) }; /** - * find the item in a given array that has the difficulty closest to the target value + * Find the item in a given array that has the difficulty closest to the target value * * @remarks * The input array of stimuli must be sorted by difficulty. * - * @param arr Array - an array of stimuli sorted by difficulty - * @param target number - ability estimate - * @returns {number} the index of arr + * @param {Stimulus[]} inputStimuli - an array of stimuli sorted by difficulty + * @param {number} target - ability estimate + * @returns {number} the index of stimuli */ -export const findClosest = (arr: Array, target: number) => { +export const findClosest = (inputStimuli: Array, target: number) => { + const stimuli = inputStimuli.map((stim) => fillZetaDefaults(stim, 'semantic')); // Let's consider the edge cases first - if (target <= arr[0].difficulty) { + if (target <= stimuli[0].difficulty!) { return 0; - } else if (target >= arr[arr.length - 1].difficulty) { - return arr.length - 1; + } else if (target >= stimuli[stimuli.length - 1].difficulty!) 
{ + return stimuli.length - 1; } const comparitor = (element: Stimulus, needle: number) => { - return element.difficulty - needle; + return element.difficulty! - needle; }; - const indexOfTarget = bs(arr, target, comparitor); + const indexOfTarget = bs(stimuli, target, comparitor); if (indexOfTarget >= 0) { // `bs` returns a positive integer index if it found an exact match. @@ -79,8 +88,8 @@ export const findClosest = (arr: Array, target: number) => { // So we simply compare the differences between the target and the high and // low values, respectively - const lowDiff = Math.abs(arr[lowIndex].difficulty - target); - const highDiff = Math.abs(arr[highIndex].difficulty - target); + const lowDiff = Math.abs(stimuli[lowIndex].difficulty! - target); + const highDiff = Math.abs(stimuli[highIndex].difficulty! - target); if (lowDiff < highDiff) { return lowIndex;