From 25e359ed68984e538a447d056a46a4172552b266 Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford <52039264+ngua@users.noreply.github.com> Date: Mon, 23 Sep 2024 18:47:15 +0700 Subject: [PATCH] [inferno-ml] Simplify model permissions, add more metadata (#132) Includes several changes to the `Model` and related types: - Model permissions are now expressed as a single group ID (which simplifies things considerably) - Models now contain `visibility` and `updated` fields, similar to Inferno VC scripts - Some smaller changes to `ModelVersion` (adding a high-level description, etc...) - Removal of `users` table, which isn't necessary Since we don't need to support GHC 8 any longer, I've switched to using `-XOverloadedRecordDot` in some places. I'm going to remove GHC 8 support entirely when I finish the compiler upgrade (blocked on NVIDIA pain at the moment). --- .../src/Inferno/ML/Server/Types.hs | 277 ++++++------------ .../src/Inferno/ML/Server/Types.hs | 4 +- inferno-ml-server/test/Main.hs | 1 - nix/default.nix | 7 +- .../migrations/v1-create-tables.sql | 23 +- nix/inferno-ml/tests/server.nix | 30 +- 6 files changed, 116 insertions(+), 226 deletions(-) diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index 98259924..7aae8cee 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -7,6 +7,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -Wno-unticked-promoted-constructors #-} @@ -26,7 +27,7 @@ import Data.ByteString (ByteString) import qualified Data.ByteString.Char8 as ByteString.Char8 import Data.Char (chr) import Data.Data (Typeable) -import Data.Generics.Product (HasType (typed), the) +import Data.Generics.Product (HasType (typed)) import Data.Generics.Wrapped (wrappedFrom, wrappedTo) import Data.Hashable (Hashable) import qualified Data.IP @@ -73,7 +74,7 @@ import Inferno.Types.VersionControl byteStringToVCObjectHash, vcObjectHashToByteString, ) -import Inferno.VersionControl.Types (VCMeta, VCObject) +import Inferno.VersionControl.Types (VCMeta, VCObject, VCObjectVisibility) import Lens.Micro.Platform hiding ((.=)) import Servant ( Capture, @@ -184,9 +185,9 @@ instance FromRow (BridgeInfo uid gid p s) where instance ToRow (BridgeInfo uid gid p s) where toRow bi = - [ bi ^. the @"id" & toField, - bi ^. the @"host" & toField, - bi ^. the @"port" & toField + [ bi.id & toField, + bi.host & toField, + bi.port & toField ] -- | The ID of a database entity @@ -242,8 +243,8 @@ instance ToField VCObjectHashRow where instance (ToJSON uid, ToJSON gid) => ToRow (InferenceScript uid gid) where toRow s = -- NOTE: Don't change the order! - [ s ^. the @"hash" & VCObjectHashRow & toField, - s ^. the @"obj" & Aeson & toField + [ s.hash & VCObjectHashRow & toField, + s.obj & Aeson & toField ] instance @@ -272,118 +273,91 @@ instance -- versions, e.g. model name and permissions. A second table, 'ModelVersion', -- contains the specific versions of each model (and the actual model contents), -- along with other metadata that may change between versions -data Model uid gid = Model - { id :: Maybe (Id (Model uid gid)), +data Model gid = Model + { id :: Maybe (Id (Model gid)), name :: Text, - -- | Permissions for reading or updating the model, keyed by the group ID - -- type - -- - -- NOTE: This is stored as a @jsonb@ rather than as @hstore@. It could - -- currently be stored as an @hstore@, but later we might want to - -- use a more complex type that we could not easily convert to\/from - -- text (which is required to use @hstore@). So using @jsonb@ allows - -- for greater potential flexibility - permissions :: Map gid ModelPermissions, - -- | The user who owns the model, if any. Note that owning a model - -- will implicitly set permissions - user :: Maybe uid, + -- | The group that owns this model + gid :: gid, + -- | Analogous to visibility of @inferno-vc@ scripts + visibility :: VCObjectVisibility, + -- | The last time the model was updated (i.e. a new version was created), + -- if any + updated :: Maybe UTCTime, -- | The time that this model was \"deleted\", if any. For active models, -- this will be @Nothing@ terminated :: Maybe UTCTime } deriving stock (Show, Eq, Generic) -instance NFData (Model uid gid) where +instance NFData (Model gid) where rnf = rwhnf instance ( Typeable gid, - FromField uid, FromField gid, - FromJSONKey gid, Ord gid ) => - FromRow (Model uid gid) + FromRow (Model gid) where -- NOTE: Order of fields must align exactly with DB schema fromRow = Model <$> field <*> field + <*> field <*> fmap getAeson field <*> field <*> field -instance - ( ToField uid, - ToField gid, - ToJSONKey gid - ) => - ToRow (Model uid gid) - where +instance ToField gid => ToRow (Model gid) where -- NOTE: Order of fields must align exactly with DB schema toRow m = [ toField Default, - m ^. the @"name" & toField, - m ^. the @"permissions" & Aeson & toField, - m ^. the @"user" & toField, - -- The `ToRow` instance is only for new rows, so we don't want - -- to set the `terminated` field to anything by default + m.name & toField, + m.gid & toField, + m.visibility & Aeson & toField, + -- The `ToRow` instance is only for new rows, so we don't want to set + -- the `updated` and `terminated` fields to anything by default -- - -- The same applies to the other `toField Default`s for different - -- types below + -- The same rationale applies to the other `toField Default`s for + -- different types below + toField Default, toField Default ] {- ORMOLU_DISABLE -} instance - ( FromJSON uid, - FromJSONKey gid, + ( FromJSON gid, Ord gid ) => - FromJSON (Model uid gid) + FromJSON (Model gid) where parseJSON = withObject "Model" $ \o -> Model -- If a new model is being created, its ID will not be present <$> o .:? "id" - <*> (ensureNotNull =<< o .: "name") - <*> o .: "permissions" - <*> o .:? "user" + <*> o .: "name" + <*> o .: "gid" + <*> o .: "visibility" + <*> o .:? "updated" -- If a new model is being serialized, it does not really make -- sense to require a `"terminated": null` field <*> o .:? "terminated" - where - ensureNotNull :: Text -> Parser Text - ensureNotNull - t - | Text.null t = fail "Field cannot be empty" - | otherwise = pure t {- ORMOLU_ENABLE -} -instance - ( ToJSON uid, - ToJSONKey gid - ) => - ToJSON (Model uid gid) - where +instance ToJSON gid => ToJSON (Model gid) where toJSON m = object - [ "id" .= view (the @"id") m, - "name" .= view (the @"name") m, - "permissions" .= view (the @"permissions") m, - "user" .= view (the @"user") m, - "terminated" .= view (the @"terminated") m + [ "id" .= m.id, + "name" .= m.name, + "gid" .= m.gid, + "visibility" .= m.visibility, + "updated" .= m.updated, + "terminated" .= m.terminated ] -- Not derived generically in order to use special `Gen UTCTime` -instance - ( Ord gid, - Arbitrary gid, - Arbitrary uid - ) => - Arbitrary (Model uid gid) - where +instance (Ord gid, Arbitrary gid) => Arbitrary (Model gid) where arbitrary = Model <$> arbitrary @@ -391,12 +365,10 @@ instance <*> arbitrary <*> arbitrary <*> genMUtc + <*> genMUtc -- Can't be derived because there is (intentially) no `Arbitrary UTCTime` in scope -instance - (Arbitrary uid, Arbitrary gid, Ord gid) => - ToADTArbitrary (Model uid gid) - where +instance (Arbitrary gid, Ord gid) => ToADTArbitrary (Model gid) where toADTArbitrarySingleton _ = ADTArbitrarySingleton "Inferno.ML.Server.Types" "Model" . ConstructorArbitraryPair "Model" @@ -412,11 +384,12 @@ instance -- content, which will normally be an 'Oid' (Postgres large object). Other -- model metadata is contained here as well, e.g. the model card, as this -- might change between versions -data ModelVersion uid gid c = ModelVersion - { id :: Maybe (Id (ModelVersion uid gid c)), +data ModelVersion gid c = ModelVersion + { id :: Maybe (Id (ModelVersion gid c)), -- | Foreign key of the @model@ table, which contains invariant metadata -- related to the model, i.e. name, permissions, user - model :: Id (Model uid gid), + model :: Id (Model gid), + description :: Text, card :: ModelCard, -- | The actual contents of version of the model. Normally this will be -- an 'Oid' pointing to the serialized bytes of the model imported into @@ -431,12 +404,7 @@ data ModelVersion uid gid c = ModelVersion -- NOTE: This may require an orphan instance for the `c` type variable deriving anyclass (NFData) -instance - ( FromField uid, - FromField gid - ) => - FromRow (ModelVersion uid gid Oid) - where +instance FromField gid => FromRow (ModelVersion gid Oid) where -- NOTE: Order of fields must align exactly with DB schema. This instance -- could just be `anyclass` derived but it's probably better to be as -- explicit as possible @@ -448,34 +416,27 @@ instance <*> field <*> field <*> field + <*> field -instance - ( ToField uid, - ToField gid - ) => - ToRow (ModelVersion uid gid Oid) - where +instance ToField gid => ToRow (ModelVersion gid Oid) where -- NOTE: Order of fields must align exactly with DB schema toRow mv = [ toField Default, - mv ^. the @"model" & toField, - mv ^. the @"card" & Aeson & toField, - mv ^. the @"contents" & toField, - mv ^. the @"version" & toField, + mv.model & toField, + mv.description & toField, + mv.card & Aeson & toField, + mv.contents & toField, + mv.version & toField, toField Default ] {- ORMOLU_DISABLE -} -instance - ( FromJSON uid, - FromJSON gid - ) => - FromJSON (ModelVersion uid gid Oid) - where +instance FromJSON gid => FromJSON (ModelVersion gid Oid) where parseJSON = withObject "ModelVersion" $ \o -> ModelVersion <$> o .:? "id" <*> o .: "model" + <*> o .: "description" <*> o .: "card" <*> fmap (Oid . fromIntegral @Word64) (o .: "contents") <*> o .: "version" @@ -484,27 +445,23 @@ instance <*> o .:? "terminated" {- ORMOLU_ENABLE -} -instance - ( ToJSON uid, - ToJSON gid - ) => - ToJSON (ModelVersion uid gid Oid) - where +instance ToJSON gid => ToJSON (ModelVersion gid Oid) where toJSON mv = object - [ "id" .= view (the @"id") mv, - "model" .= view (the @"model") mv, - "contents" .= view (the @"contents" . to unOid) mv, - "version" .= view (the @"version") mv, - "card" .= view (the @"card") mv, - "terminated" .= view (the @"terminated") mv + [ "id" .= mv.id, + "model" .= mv.model, + "description" .= mv.description, + "contents" .= unOid mv.contents, + "version" .= mv.version, + "card" .= mv.card, + "terminated" .= mv.terminated ] where unOid :: Oid -> Word32 unOid (Oid (CUInt x)) = x -- Not derived generically in order to use special `Gen UTCTime` -instance Arbitrary c => Arbitrary (ModelVersion uid gid c) where +instance Arbitrary c => Arbitrary (ModelVersion gid c) where arbitrary = ModelVersion <$> arbitrary @@ -512,10 +469,11 @@ instance Arbitrary c => Arbitrary (ModelVersion uid gid c) where <*> arbitrary <*> arbitrary <*> arbitrary + <*> arbitrary <*> genMUtc -- Can't be derived because there is (intentially) no `Arbitrary UTCTime` in scope -instance (Arbitrary c) => ToADTArbitrary (ModelVersion uid gid c) where +instance Arbitrary c => ToADTArbitrary (ModelVersion gid c) where toADTArbitrarySingleton _ = ADTArbitrarySingleton "Inferno.ML.Server.Types" "ModelVersion" . ConstructorArbitraryPair "ModelVersion" @@ -525,34 +483,10 @@ instance (Arbitrary c) => ToADTArbitrary (ModelVersion uid gid c) where ADTArbitrary "Inferno.ML.Server.Types" "ModelVersion" <$> sequence [ConstructorArbitraryPair "ModelVersion" <$> arbitrary] --- | Permissions for reading or writing a model -data ModelPermissions - = -- | The model can be read e.g. for inference - ReadModel - | -- | The model can be updated e.g. during training - WriteModel - deriving stock (Show, Eq, Generic) - deriving anyclass (NFData, ToADTArbitrary) - -instance FromJSON ModelPermissions where - parseJSON = withText "ModelPermissions" $ \case - "read" -> pure ReadModel - "write" -> pure WriteModel - t -> fail $ unwords ["Invalid model permissions:", Text.unpack t] - -instance ToJSON ModelPermissions where - toJSON = - String . \case - ReadModel -> "read" - WriteModel -> "write" - -instance Arbitrary ModelPermissions where - arbitrary = genericArbitrary - -- | Full description and metadata of the model data ModelCard = ModelCard { -- | High-level, structured overview of model details and summary - description :: ModelDescription, + summary :: ModelSummary, metadata :: ModelMetadata } deriving stock (Show, Eq, Generic) @@ -562,40 +496,35 @@ data ModelCard = ModelCard instance Arbitrary ModelCard where arbitrary = genericArbitrary --- | Structured description of a model -data ModelDescription = ModelDescription - { -- | General summary of model, cannot be empty +-- | Structured summary of a model +data ModelSummary = ModelSummary + { -- | General summary of model (longer than top-level @description@ field + -- of 'ModelVersion' type) summary :: Text, -- | How the model is intended to be used uses :: Text, - -- | Applicable limitations, risks, biases, etc... - risks :: Text, - -- | Details on training data, speed\/size of training elements, etc... - training :: Text, evaluation :: Text } deriving stock (Show, Eq, Generic) deriving anyclass (ToJSON, NFData, ToADTArbitrary) {- ORMOLU_DISABLE -} -instance FromJSON ModelDescription where - parseJSON = withObject "ModelDescription" $ \o -> - ModelDescription +instance FromJSON ModelSummary where + parseJSON = withObject "ModelSummary" $ \o -> + ModelSummary <$> o .: "summary" <*> o .:? "uses" .!= mempty - <*> o .:? "risks" .!= mempty - <*> o .:? "training" .!= mempty <*> o .:? "evaluation" .!= mempty {- ORMOLU_ENABLE -} -instance Arbitrary ModelDescription where +instance Arbitrary ModelSummary where arbitrary = genericArbitrary -- | Metadata for the model, inspired by Hugging Face model card format data ModelMetadata = ModelMetadata { categories :: Vector Int, - datasets :: Vector Text, - metrics :: Vector Text, + datasets :: Text, + metrics :: Text, baseModel :: Maybe Text } deriving stock (Show, Eq, Generic) @@ -746,7 +675,7 @@ data InferenceParam uid gid p s = InferenceParam -- | The time that this parameter was \"deleted\", if any. For active -- parameters, this will be @Nothing@ terminated :: Maybe UTCTime, - user :: uid + uid :: uid } deriving stock (Show, Eq, Generic) deriving anyclass (NFData, ToJSON) @@ -768,7 +697,7 @@ instance <*> o .:? "resolution" .!= 128 -- We shouldn't require this field <*> o .:? "terminated" - <*> o .: "user" + <*> o .: "uid" {- ORMOLU_ENABLE -} -- We only want this instance if the `script` is a `VCObjectHash` (because it @@ -800,11 +729,11 @@ instance -- NOTE: Do not change the order of the field actions toRow ip = [ toField Default, - ip ^. the @"script" & VCObjectHashRow & toField, - ip ^. the @"inputs" & Aeson & toField, - ip ^. the @"resolution" & Aeson & toField, + ip.script & VCObjectHashRow & toField, + ip.inputs & Aeson & toField, + ip.resolution & Aeson & toField, toField Default, - ip ^. the @"user" & toField + ip.uid & toField ] -- Not derived generically in order to use special `Gen UTCTime` @@ -848,7 +777,7 @@ data InferenceParamWithModels uid gid p s = InferenceParamWithModels models :: Map Ident - ( Id (ModelVersion uid gid Oid), + ( Id (ModelVersion gid Oid), -- Name of parent model Text ) @@ -926,12 +855,12 @@ instance FromRow (EvaluationInfo uid gid p) where instance ToRow (EvaluationInfo uid gid p) where toRow ei = - [ ei ^. the @"id" & toField, - ei ^. the @"param" & toField, - ei ^. the @"start" & toField, - ei ^. the @"end" & toField, - ei ^. the @"allocated" & toField, - ei ^. the @"cpu" & toField + [ ei.id & toField, + ei.param & toField, + ei.start & toField, + ei.end & toField, + ei.allocated & toField, + ei.cpu & toField ] -- Not derived generically in order to use special `Gen UTCTime` @@ -945,24 +874,6 @@ instance Arbitrary (EvaluationInfo uid gid p) where <*> arbitrary <*> arbitrary --- | A user, parameterized by the user and group types -data User uid gid = User - { id :: uid, - groups :: Vector gid - } - deriving stock (Show, Generic, Eq) - deriving anyclass - ( FromRow, - ToRow, - FromJSON, - ToJSON, - NFData, - ToADTArbitrary - ) - -instance (Arbitrary uid, Arbitrary gid) => Arbitrary (User uid gid) where - arbitrary = genericArbitrary - -- | IPv4 address with some useful instances newtype IPv4 = IPv4 Data.IP.IPv4 deriving stock (Generic) diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index ec2ff2fe..fee50819 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -414,9 +414,9 @@ type BridgeInfo = type EvaluationInfo = Types.EvaluationInfo (EntityId UId) (EntityId GId) PID -type Model = Types.Model (EntityId UId) (EntityId GId) +type Model = Types.Model (EntityId GId) -type ModelVersion = Types.ModelVersion (EntityId UId) (EntityId GId) Oid +type ModelVersion = Types.ModelVersion (EntityId GId) Oid type InferenceScript = Types.InferenceScript ScriptMetadata (EntityId GId) diff --git a/inferno-ml-server/test/Main.hs b/inferno-ml-server/test/Main.hs index 2de6707b..ef17c04e 100644 --- a/inferno-ml-server/test/Main.hs +++ b/inferno-ml-server/test/Main.hs @@ -107,7 +107,6 @@ mkDbSpec env = Hspec.describe "Database" $ do | Just (model, mversion) <- v ^? _head -> do view #name model `Hspec.shouldBe` "mnist" view (#version . to showVersion) mversion `Hspec.shouldBe` "v1" - view #user model `Hspec.shouldBe` Nothing | otherwise -> Hspec.expectationFailure "No models were retrieved" Hspec.it "gets model size and contents" $ do diff --git a/nix/default.nix b/nix/default.nix index d59407f3..dd3db627 100644 --- a/nix/default.nix +++ b/nix/default.nix @@ -86,8 +86,11 @@ pkgs.haskell-nix.cabalProject { ); }; }; - buildInputs = [ config.treefmt.build.wrapper ] - ++ builtins.attrValues config.treefmt.build.programs; + buildInputs = [ + pkgs.postgresql + config.treefmt.build.wrapper + ] + ++ builtins.attrValues config.treefmt.build.programs; shellHook = let setpath = lib.optionalString cudaSupport diff --git a/nix/inferno-ml/migrations/v1-create-tables.sql b/nix/inferno-ml/migrations/v1-create-tables.sql index a73d1791..27a85135 100644 --- a/nix/inferno-ml/migrations/v1-create-tables.sql +++ b/nix/inferno-ml/migrations/v1-create-tables.sql @@ -16,30 +16,23 @@ create extension lo; -- caching models, etc... If the field is not null, then the entity has been -- "deleted" and cannot be used any longer -create table if not exists users - ( -- Note: this is the bson object ID represented as an integer - id integer primary key - -- Also a list of bson object IDs. This determines model access (see below) - , groups integer[] not null - ); - create table if not exists models ( id serial primary key , name text not null - -- Represented as a map from group IDs to model permissions (read or write), - -- serialized to JSON. This is a bit more flexible than using an `hstore` and - -- might allow us to include a more complex structure in the future more - -- easily - , permissions jsonb not null - , "user" integer references users (id) + , gid bigint not null + , visibility jsonb + -- May be missing, if there is no model version yet + , updated timestamptz -- See note above , terminated timestamptz - , unique (name, "user") + , unique (name, gid) ); create table if not exists mversions ( id serial primary key , model integer references models (id) + -- Short, high-level model description + , description text not null -- Model card (description and metadata) serialized as JSON , card jsonb not null -- The model contents are not stored directly because it might exceed @@ -84,7 +77,7 @@ create table if not exists params , resolution integer not null -- See note above , terminated timestamptz - , "user" integer references users (id) + , uid bigint not null ); -- Execution info for inference evaluation diff --git a/nix/inferno-ml/tests/server.nix b/nix/inferno-ml/tests/server.nix index 77948f50..823d6320 100644 --- a/nix/inferno-ml/tests/server.nix +++ b/nix/inferno-ml/tests/server.nix @@ -28,37 +28,34 @@ pkgs.nixosTest { text = let card = builtins.toJSON { - description.summary = "A model"; + summary.summary = "A model"; metadata = { }; }; - # Note that the nested list is how Aeson will decode/encode - # a map, which is the Haskell value for this field - permissions = builtins.toJSON [ - [ "o000000000000000000000001" "read" ] - ]; in '' psql -U inferno -d inferno << EOF INSERT INTO models ( name - , permissions - , "user" + , gid + , visibility ) VALUES ( 'mnist' - , '${permissions}'::jsonb - , NULL + , 1::bigint + , '"VCObjectPublic"'::jsonb ); \lo_import ${./models/mnist.ts.pt} INSERT INTO mversions ( model + , description , card , contents , version ) VALUES ( 1 + , 'My first model' , '${card}'::jsonb , :LASTOID , 'v1' @@ -67,18 +64,6 @@ pkgs.nixosTest { ''; } ) - ( - pkgs.writeShellApplication { - name = "insert-user"; - runtimeInputs = [ pkgs.postgresql ]; - text = '' - psql -U inferno -d inferno << EOF - INSERT INTO "users" (id, groups) - VALUES (0, '{1}'); - EOF - ''; - } - ) ( pkgs.writeShellApplication { name = "parse-scripts-and-save-params"; @@ -214,7 +199,6 @@ pkgs.nixosTest { node.succeed( 'psql -U inferno -d inferno -f ${../migrations/v1-create-tables.sql}' ) - node.succeed('insert-user') node.succeed('insert-mnist-model') node.succeed('sudo -HE -u inferno parse-scripts-and-save-params')