Skip to content

Commit

Permalink
Merge pull request #1410 from cornellius-gp/multitask-dgp-example
Browse files Browse the repository at this point in the history
Add multitask DGP example.
  • Loading branch information
gpleiss authored Jan 19, 2021
2 parents 2b3014a + 6982cbf commit f1ae39d
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.7.0"
}
},
"nbformat": 4,
Expand Down
307 changes: 307 additions & 0 deletions examples/05_Deep_Gaussian_Processes/DGP_Multitask_Regression.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
"import torch\n",
"import tqdm\n",
"import gpytorch\n",
"from torch.nn import Linear\n",
"from gpytorch.means import ConstantMean, LinearMean\n",
"from gpytorch.kernels import RBFKernel, ScaleKernel\n",
"from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution\n",
Expand Down Expand Up @@ -169,10 +168,8 @@
" batch_shape=batch_shape, ard_num_dims=None\n",
" )\n",
"\n",
" self.linear_layer = Linear(input_dims, 1)\n",
"\n",
" def forward(self, x):\n",
" mean_x = self.mean_module(x) # self.linear_layer(x).squeeze(-1)\n",
" mean_x = self.mean_module(x)\n",
" covar_x = self.covar_module(x)\n",
" return MultivariateNormal(mean_x, covar_x)\n",
"\n",
Expand Down Expand Up @@ -544,7 +541,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.7.0"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"from gpytorch.kernels import ScaleKernel, MaternKernel\n",
"from gpytorch.variational import VariationalStrategy, BatchDecoupledVariationalStrategy\n",
"from gpytorch.variational import MeanFieldVariationalDistribution\n",
"from gpytorch.models.deep_gps import DeepGP\n",
"from gpytorch.models.deep_gps.dspp import DSPPLayer, DSPP\n",
"import gpytorch.settings as settings"
]
Expand Down Expand Up @@ -487,7 +486,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
"version": "3.7.0"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions examples/05_Deep_Gaussian_Processes/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

Deep_Gaussian_Processes.ipynb
Deep_Sigma_Point_Processes.ipynb
DGP_Multitask_Regression.ipynb

0 comments on commit f1ae39d

Please sign in to comment.