Skip to content

Commit

Permalink
[inferno-ml] Make sure evaluation environment is always valid (#153)
Browse files Browse the repository at this point in the history
Currently, if part of the `EvaluationEnvironment` is missing, e.g.
inputs or outputs, script evaluation will fail. In this case, the most
sensible thing to do is just use whatever is linked to the param. This
allows for selectively overriding certain parts of the environment
without the whole thing failing
  • Loading branch information
ngua authored Jan 21, 2025
1 parent 1ed2b66 commit c323f3d
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 37 deletions.
1 change: 0 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ jobs:
with:
name: inferno
authToken: "${{ secrets.CACHIX_TOKEN }}"
- uses: DeterminateSystems/magic-nix-cache-action@main

# Build inferno and run all tests
- run: |
Expand Down
3 changes: 3 additions & 0 deletions inferno-ml-server-types/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Revision History for inferno-ml-server-types
*Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH)

## 0.12.1
* Add some convenience type synonyms

## 0.12.0
* Add creation date to models and versions

Expand Down
24 changes: 18 additions & 6 deletions inferno-ml-server-types/src/Inferno/ML/Server/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,18 @@ type BridgeAPI p t =
-- This means the same output may appear more than once in the stream
type WriteStream m = ConduitT () (Int, [(EpochTime, IValue)]) m ()

-- | Convenience synonym for the consumed 'WriteStream'
type Writes p = Map p [(Double, EpochTime)]

-- | Just a convenience synonym for cleaning up type signatures
type Inputs p = Map Ident (SingleOrMany p)

-- | Same as 'Inputs' above
type Outputs p = Map Ident (SingleOrMany p)

-- | Same as 'Inputs', 'Outputs'
type Models a = Map Ident a

data ServerStatus
= Idle
| EvaluatingScript
Expand Down Expand Up @@ -708,8 +720,8 @@ data InferenceParam gid p = InferenceParam
-- | Mapping the input\/output to the Inferno identifier helps ensure that
-- Inferno identifiers are always pointing to the correct input\/output;
-- otherwise we would need to rely on the order of the original identifiers
inputs :: Map Ident (SingleOrMany p),
outputs :: Map Ident (SingleOrMany p),
inputs :: Inputs p,
outputs :: Outputs p,
-- | Resolution, passed to bridge routes
resolution :: Word64,
-- | The time that this parameter was \"deleted\", if any. For active
Expand Down Expand Up @@ -795,7 +807,7 @@ instance (Arbitrary gid, Arbitrary p) => ToADTArbitrary (InferenceParam gid p) w
-- linked to it indirectly via its script. This is provided for convenience
data InferenceParamWithModels gid p = InferenceParamWithModels
{ param :: InferenceParam gid p,
models :: Map Ident (Id (ModelVersion gid Oid))
models :: Models (Id (ModelVersion gid Oid))
}
deriving stock (Show, Eq, Generic)

Expand Down Expand Up @@ -1023,9 +1035,9 @@ instance Ord a => Ord (SingleOrMany a) where
-- evaluator. This allows for more interactive testing
data EvaluationEnv gid p = EvaluationEnv
{ script :: VCObjectHash,
inputs :: Map Ident (SingleOrMany p),
outputs :: Map Ident (SingleOrMany p),
models :: Map Ident (Id (ModelVersion gid Oid))
inputs :: Inputs p,
outputs :: Inputs p,
models :: Models (Id (ModelVersion gid Oid))
}
deriving stock (Show, Eq, Generic)
deriving anyclass (FromJSON, ToJSON, ToADTArbitrary)
Expand Down
3 changes: 3 additions & 0 deletions inferno-ml-server/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Revision History for `inferno-ml-server`

## 2025.12.21
* Always make sure relevant items are set in evaluation environment (testing route)

## 2024.11.29
* Use `Pool` to hold Postgres connections

Expand Down
4 changes: 2 additions & 2 deletions inferno-ml-server/exe/ParseAndSave.hs
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ funs = BridgeFuns notSupported notSupported notSupported notSupported
notSupported = error "Not supported"

data InputsOutputs = InputsOutputs
{ inputs :: Map Ident (SingleOrMany PID),
outputs :: Map Ident (SingleOrMany PID)
{ inputs :: Inputs PID,
outputs :: Outputs PID
}
deriving stock (Generic)
deriving anyclass (FromJSON)
56 changes: 35 additions & 21 deletions inferno-ml-server/src/Inferno/ML/Server/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ import Inferno.ML.Types.Value (MlValue (VExtended, VModelName))
import Inferno.Types.Syntax
( Expr (App, Var),
ExtIdent (ExtIdent),
Ident,
ImplExpl (Expl),
Scoped (LocalScope),
)
Expand Down Expand Up @@ -116,7 +115,7 @@ runInferenceParam ::
runInferenceParam ipid mres uuid =
runInferenceParamWithEnv ipid uuid
=<< mkScriptEnv
=<< getParameterWithModels ipid
=<< getInferenceParamWithModels ipid
where
mkScriptEnv :: InferenceParamWithModels -> RemoteM ScriptEnv
mkScriptEnv pwm =
Expand All @@ -140,24 +139,40 @@ testInferenceParam ipid mres uuid eenv =
=<< mkScriptEnv
-- Just need to get the param, we already have the model information
-- from the overrides
=<< getParam
=<< getInferenceParamWithModels ipid
where
-- Note that, unlike `runInferenceParam`, several of the items required
-- for script eval come from the `EvaluationEnv`
mkScriptEnv :: InferenceParam -> RemoteM ScriptEnv
mkScriptEnv param =
ScriptEnv param eenv.models eenv.inputs eenv.outputs
-- for script eval MAY come from the `EvaluationEnv` if they have been
-- overridden. See the bindings for `inputs`, `outputs`, and `models` below
-- for an explanation
mkScriptEnv :: InferenceParamWithModels -> RemoteM ScriptEnv
mkScriptEnv pwm =
ScriptEnv pwm.param models inputs outputs
<$> getVcObject eenv.script
?? eenv.script
?? mres

getParam :: RemoteM InferenceParam
getParam =
firstOrThrow (NoSuchParameter ipid)
=<< queryStore q (Only ipid)
where
q :: Query
q = [sql| SELECT * FROM params WHERE id = ? |]
-- If the `inputs` have not been specified in the evaluation env, i.e.
-- the inputs are not being overridden, use the ones that are linked
-- directly to the param
inputs :: Inputs PID
inputs
| null eenv.inputs = pwm.param.inputs
| otherwise = eenv.inputs

-- Likewise, if the `outputs` have not been overridden, use the ones
-- that are linked directly to the param
outputs :: Outputs PID
outputs
| null eenv.outputs = pwm.param.outputs
| otherwise = eenv.outputs

-- Likewise, if the `models` have not been overridden, use the ones
-- that are linked directly to the param via its inference script
models :: Models (Id ModelVersion)
models
| null eenv.models = pwm.models
| otherwise = eenv.models

runInferenceParamWithEnv ::
Id InferenceParam ->
Expand Down Expand Up @@ -433,8 +448,8 @@ getVcObject vch =
q :: Query
q = [sql| SELECT * FROM scripts WHERE id = ? |]

getParameterWithModels :: Id InferenceParam -> RemoteM InferenceParamWithModels
getParameterWithModels ipid =
getInferenceParamWithModels :: Id InferenceParam -> RemoteM InferenceParamWithModels
getInferenceParamWithModels ipid =
fmap
( uncurry InferenceParamWithModels
. fmap (getAeson . fromOnly)
Expand Down Expand Up @@ -490,8 +505,7 @@ getParameterWithModels ipid =
--
-- NOTE: This action assumes that the current working directory is the model
-- cache! It can be run using e.g. 'withCurrentDirectory'
getAndCacheModels ::
ModelCache -> Map Ident (Id ModelVersion) -> RemoteM ()
getAndCacheModels :: ModelCache -> Models (Id ModelVersion) -> RemoteM ()
getAndCacheModels cache =
traverse_ (uncurry copyAndCache)
<=< getModelsAndVersions
Expand Down Expand Up @@ -586,9 +600,9 @@ mkModelPath = (<.> "ts" <.> "pt") . UUID.toString . wrappedTo
-- endpoint, these will be overridden
data ScriptEnv = ScriptEnv
{ param :: InferenceParam,
models :: Map Ident (Id ModelVersion),
inputs :: Map Ident (SingleOrMany PID),
outputs :: Map Ident (SingleOrMany PID),
models :: Models (Id ModelVersion),
inputs :: Inputs PID,
outputs :: Outputs PID,
obj :: VCMeta VCObject,
script :: VCObjectHash,
mres :: Maybe Int64
Expand Down
10 changes: 4 additions & 6 deletions inferno-ml-server/src/Inferno/ML/Server/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ import qualified Data.ByteString.Char8 as ByteString.Char8
import Data.Data (Typeable)
import Data.Generics.Labels ()
import Data.Generics.Wrapped (wrappedTo)
import Data.Map.Strict (Map)
import Data.Pool (Pool)
import qualified Data.Pool as Pool
import Data.Scientific (Scientific)
Expand Down Expand Up @@ -91,7 +90,6 @@ import "inferno-ml-server-types" Inferno.ML.Server.Types as M hiding
ModelVersion,
)
import qualified "inferno-ml-server-types" Inferno.ML.Server.Types as Types
import Inferno.Types.Syntax (Ident)
import Inferno.VersionControl.Types
( VCObject,
VCObjectHash,
Expand Down Expand Up @@ -297,7 +295,7 @@ instance FromJSON ScriptType where
_ -> pure OtherScript

newtype InferenceOptions = InferenceOptions
{ models :: Map Ident (Id ModelVersion)
{ models :: Models (Id ModelVersion)
}
deriving stock (Show, Eq, Generic)
deriving anyclass (FromJSON, ToJSON)
Expand Down Expand Up @@ -442,8 +440,8 @@ pattern InferenceScript h o = Types.InferenceScript h o
pattern InferenceParam ::
Maybe (Id InferenceParam) ->
VCObjectHash ->
Map Ident (SingleOrMany PID) ->
Map Ident (SingleOrMany PID) ->
Inputs PID ->
Outputs PID ->
Word64 ->
Maybe UTCTime ->
EntityId GId ->
Expand All @@ -452,7 +450,7 @@ pattern InferenceParam iid s is os res mt gid =
Types.InferenceParam iid s is os res mt gid

pattern InferenceParamWithModels ::
InferenceParam -> Map Ident (Id ModelVersion) -> InferenceParamWithModels
InferenceParam -> Models (Id ModelVersion) -> InferenceParamWithModels
pattern InferenceParamWithModels ip mvs = Types.InferenceParamWithModels ip mvs

pattern BridgeInfo :: Id InferenceParam -> IPv4 -> Word64 -> BridgeInfo
Expand Down
2 changes: 1 addition & 1 deletion inferno-ml-server/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ mkCacheSpec env = Hspec.before_ clearCache . Hspec.describe "Model cache" $ do
cdCache :: IO a -> IO a
cdCache = withCurrentDirectory env.config.cache.path

modelsWithIdents :: Map Ident (Id ModelVersion)
modelsWithIdents :: Models (Id ModelVersion)
modelsWithIdents = Map.singleton "dummy" mnistV1

mkDbSpec :: Env -> Spec
Expand Down

0 comments on commit c323f3d

Please sign in to comment.