From 838770a367044eaf26e9c9ad90c6b5aaa0013d03 Mon Sep 17 00:00:00 2001 From: Yann Bizeul Date: Wed, 4 Sep 2024 15:42:09 +0200 Subject: [PATCH] OIDC implementation (#39) Implement OID authentication --- .vscode/launch.json | 2 +- html/src/APIClient.ts | 31 ++++- html/src/App.tsx | 21 ++- html/src/Components/Haffix.tsx | 5 +- html/src/LoggedInContext.ts | 10 +- html/src/Pages/Login.tsx | 23 +++- html/src/Pages/SharesPage.tsx | 13 +- html/vite.config.ts | 3 + hupload/.gitignore | 3 +- hupload/go.mod | 5 + hupload/go.sum | 24 ++++ hupload/handlers.go | 33 +++-- hupload/handlers_test.go | 10 +- hupload/internal/config/config.go | 13 ++ hupload/pkg/apiws/apiws.go | 63 +++++++-- hupload/pkg/apiws/apiws_test.go | 26 +++- .../apiws/authentication/authentication.go | 19 ++- hupload/pkg/apiws/authentication/default.go | 22 ++- .../pkg/apiws/authentication/default_test.go | 8 +- hupload/pkg/apiws/authentication/errors.go | 5 + hupload/pkg/apiws/authentication/file.go | 29 +++- hupload/pkg/apiws/authentication/file_test.go | 35 ++--- hupload/pkg/apiws/authentication/oidc.go | 129 ++++++++++++++++++ hupload/pkg/apiws/middleware/auth/auth.go | 101 +++++++------- .../pkg/apiws/middleware/auth/auth_test.go | 30 ++-- hupload/pkg/apiws/middleware/auth/basic.go | 33 +++-- .../pkg/apiws/middleware/auth/basic_test.go | 35 +++-- hupload/pkg/apiws/middleware/auth/jwt.go | 28 ++-- hupload/pkg/apiws/middleware/auth/jwt_test.go | 28 ++-- hupload/pkg/apiws/middleware/auth/oidc.go | 45 ++++++ hupload/pkg/apiws/middleware/auth/open.go | 2 +- .../pkg/apiws/middleware/auth/open_test.go | 40 +++--- hupload/server.go | 98 ++++++++----- 33 files changed, 709 insertions(+), 263 deletions(-) create mode 100644 hupload/pkg/apiws/authentication/oidc.go create mode 100644 hupload/pkg/apiws/middleware/auth/oidc.go diff --git a/.vscode/launch.json b/.vscode/launch.json index 80f47e3..f468aab 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,7 @@ "program": "${workspaceFolder}/hupload", "env": { "JWT_SECRET": "9e1fada26b20ddc5ce812cafb8d2cada" - } + }, }, { "name": "Launch hupload (Demo)", diff --git a/html/src/APIClient.ts b/html/src/APIClient.ts index 7a1ce8d..11791e1 100644 --- a/html/src/APIClient.ts +++ b/html/src/APIClient.ts @@ -48,19 +48,20 @@ export class APIClient { console.log("logout") return } - - login(path: string, user : string, password : string) { - return new Promise((resolve, reject) => { - this.request({ + + login(path: string, user? : string, password? : string) { + return new Promise((resolve, reject) => { + axios({ url: this.baseURL + path, method: 'POST', - auth: { + maxRedirects: 0, + auth: (user && password)?{ username: user, password: password - } + }:undefined }) .then((result) => { - resolve(result) + resolve(result.data) }) .catch (e => { reject(e) @@ -296,6 +297,11 @@ export class APIClient { } } +export interface Auth { + showLoginForm: boolean + loginUrl: string + } + class HuploadClient extends APIClient { constructor() { super('/api/v1') @@ -309,6 +315,17 @@ class HuploadClient extends APIClient { document.cookie = "X-Token" +'=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT;'; document.cookie = "X-Token-Refresh" +'=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT;'; } + auth() { + return new Promise((resolve, reject) => { + this.request({ + url: '/auth', + }) + .then((result) => { + resolve(result) + }) + .catch(reject) + }) + } } export const H = new HuploadClient() diff --git a/html/src/App.tsx b/html/src/App.tsx index 7500618..0dc7177 100644 --- a/html/src/App.tsx +++ b/html/src/App.tsx @@ -5,29 +5,28 @@ import { useEffect, useState } from "react"; import { Container, MantineProvider } from "@mantine/core"; import { BrowserRouter, Route, Routes } from "react-router-dom"; -import { H } from "./APIClient"; +//import { H } from "./APIClient"; import { SharePage, Login, SharesPage } from "@/Pages"; -import { LoggedInContext } from "@/LoggedInContext"; +import { LoggedInContext, LoggedIn } from "@/LoggedInContext"; import { VersionComponent } from "@/Components"; import { Haffix } from "./Components/Haffix"; +import { H } from "./APIClient"; +//import { AxiosResponse } from "axios"; + -// Logged in user is passed to the context -interface LoggedIn { - user: string -} export default function App() { // Component state - const [loggedIn, setLoggedIn ] = useState(null) + const [loggedIn, setLoggedIn ] = useState(null) // Check with server current logged in state // This is typically executed once when Hupload is loaded // State is updated later on login page or logout button useEffect(() => { - H.post('/login').then((r) => { + H.login('/login').then((r) => { const l = r as LoggedIn - setLoggedIn(l.user) + setLoggedIn(l) }) .catch((e) => { setLoggedIn(null) @@ -40,9 +39,9 @@ export default function App() { - + }/> - {loggedIn&&}} /> + {loggedIn&&}} /> {loggedIn&&}} /> diff --git a/html/src/Components/Haffix.tsx b/html/src/Components/Haffix.tsx index 54f385a..ee0769b 100644 --- a/html/src/Components/Haffix.tsx +++ b/html/src/Components/Haffix.tsx @@ -2,9 +2,12 @@ import { ActionIcon, Affix, Tooltip } from "@mantine/core"; import { IconArrowLeft, IconLogout } from "@tabler/icons-react"; import { Link, useMatch } from "react-router-dom"; import { H } from "@/APIClient"; +import { useLoggedInContext } from "@/LoggedInContext"; export function Haffix() { const location = useMatch('/:share') + const { setLoggedIn } = useLoggedInContext() + return( <> @@ -12,7 +15,7 @@ export function Haffix() { - { H.logoutNow(); window.location.href='/'}}> + { setLoggedIn(null);H.logoutNow(); window.location.href='/'}}> diff --git a/html/src/LoggedInContext.ts b/html/src/LoggedInContext.ts index f3c8de6..4835ad6 100644 --- a/html/src/LoggedInContext.ts +++ b/html/src/LoggedInContext.ts @@ -1,8 +1,14 @@ import { createContext, useContext } from "react"; +// Logged in user is passed to the context +export interface LoggedIn { + user: string + loginPage: string +} + interface LoggedInContextValue { - loggedIn: string | null; - setLoggedIn: React.Dispatch>; + loggedIn: LoggedIn | null; + setLoggedIn: React.Dispatch>; } export const LoggedInContext = createContext(undefined); diff --git a/html/src/Pages/Login.tsx b/html/src/Pages/Login.tsx index 3b9a1e1..1fad01f 100644 --- a/html/src/Pages/Login.tsx +++ b/html/src/Pages/Login.tsx @@ -1,16 +1,18 @@ import { Alert, Button, Container, FocusTrap, Paper, PasswordInput, TextInput } from "@mantine/core"; -import { useState } from "react"; -import { APIServerError, H } from "../APIClient"; +import { useEffect, useState } from "react"; +import { APIServerError, Auth, H } from "../APIClient"; import { IconExclamationCircle } from "@tabler/icons-react"; import { useNavigate } from "react-router-dom"; import { useLoggedInContext } from "../LoggedInContext"; + export function Login() { // Initialize States const [username, setUsername] = useState("") const [password, setPassword] = useState("") const [error, setError] = useState() + const [showLoginForm, setShowLoginForm] = useState(undefined) // Initialize hooks const navigate = useNavigate(); @@ -25,7 +27,7 @@ export function Login() { setError(undefined) navigate("/shares") if (setLoggedIn !== null) { - setLoggedIn(username) + setLoggedIn({user: username, loginPage: '/login'}) } }) .catch(e => { @@ -34,6 +36,21 @@ export function Login() { } } + useEffect(() => { + H.auth() + .then((r) => { + const resp = r as Auth + setShowLoginForm(resp.showLoginForm) + if (resp.loginUrl !== document.location.pathname) { + window.location.href = resp.loginUrl + } + }) + },[navigate]) + + if (showLoginForm !== true) { + return + } + return ( diff --git a/html/src/Pages/SharesPage.tsx b/html/src/Pages/SharesPage.tsx index b69a67b..b6586b9 100644 --- a/html/src/Pages/SharesPage.tsx +++ b/html/src/Pages/SharesPage.tsx @@ -9,6 +9,7 @@ import { IconChevronDown, IconMoodSad } from "@tabler/icons-react"; import { AxiosError } from "axios"; import { ShareEditor } from "@/Components/ShareEditor"; import { useMediaQuery } from "@mantine/hooks"; + import classes from './SharesPage.module.css'; export function SharesPage(props: {owner: string|null}) { @@ -48,20 +49,18 @@ export function SharesPage(props: {owner: string|null}) { }) .catch((e) => { console.log(e) - setError(e) + if (e.response?.status === 401) { + navigate('/') + return + } }) - },[]) + },[navigate]) useEffect(() => { updateShares() },[updateShares]) if (error) { - if (error.response?.status === 401) { - navigate('/') - return - } - return (
diff --git a/html/vite.config.ts b/html/vite.config.ts index 2ab0eb4..4dd6ef8 100644 --- a/html/vite.config.ts +++ b/html/vite.config.ts @@ -9,6 +9,9 @@ export default defineConfig({ proxy: { '/api': 'http://127.0.0.1:8080/', '/d/': 'http://127.0.0.1:8080/', + '/login': 'http://127.0.0.1:8080/', + '/oidc': 'http://127.0.0.1:8080/', + '/auth': 'http://127.0.0.1:8080/', } }, }) diff --git a/hupload/.gitignore b/hupload/.gitignore index 339febc..47ce053 100644 --- a/hupload/.gitignore +++ b/hupload/.gitignore @@ -1,3 +1,4 @@ admin-ui __* -local \ No newline at end of file +local +hupload \ No newline at end of file diff --git a/hupload/go.mod b/hupload/go.mod index 80f1d79..3fac329 100644 --- a/hupload/go.mod +++ b/hupload/go.mod @@ -11,7 +11,9 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.27.30 github.com/aws/aws-sdk-go-v2/credentials v1.17.29 github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1 + github.com/coreos/go-oidc v2.2.1+incompatible golang.org/x/crypto v0.25.0 + golang.org/x/oauth2 v0.21.0 ) require ( @@ -30,4 +32,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 // indirect github.com/aws/smithy-go v1.20.4 // indirect + github.com/pquerna/cachecontrol v0.2.0 // indirect + github.com/stretchr/testify v1.8.2 // indirect + gopkg.in/square/go-jose.v2 v2.6.0 // indirect ) diff --git a/hupload/go.sum b/hupload/go.sum index e6a7c32..deff3d9 100644 --- a/hupload/go.sum +++ b/hupload/go.sum @@ -34,11 +34,35 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 h1:OMsEmCyz2i89XwRwPouAJvhj81wI github.com/aws/aws-sdk-go-v2/service/sts v1.30.5/go.mod h1:vmSqFK+BVIwVpDAGZB3CoCXHzurt4qBE8lf+I/kRTh0= github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/coreos/go-oidc v2.2.1+incompatible h1:mh48q/BqXqgjVHpy2ZY7WnWAbenxRjsz9N1i1YxjHAk= +github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/cachecontrol v0.2.0 h1:vBXSNuE5MYP9IJ5kjsdo8uq+w41jSPgvba2DEnkRx9k= +github.com/pquerna/cachecontrol v0.2.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= +gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hupload/handlers.go b/hupload/handlers.go index eddb068..038e982 100644 --- a/hupload/handlers.go +++ b/hupload/handlers.go @@ -23,7 +23,7 @@ import ( // postShare creates a new share with a randomly generate name func (h *Hupload) postShare(w http.ResponseWriter, r *http.Request) { - user := auth.UserForRequest(r) + user, _ := auth.AuthForRequest(r) // This should never happen as authentication is checked before in the // middleware @@ -67,7 +67,7 @@ func (h *Hupload) postShare(w http.ResponseWriter, r *http.Request) { // patchShare updates an existing share func (h *Hupload) patchShare(w http.ResponseWriter, r *http.Request) { - user := auth.UserForRequest(r) + user, _ := auth.AuthForRequest(r) // This should never happen as authentication is checked before in the // middleware @@ -117,8 +117,9 @@ func (h *Hupload) postItem(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusInternalServerError, err.Error()) return } + user, _ := auth.AuthForRequest(r) - if auth.UserForRequest(r) == "" && (share.Options.Exposure != "both" && share.Options.Exposure != "upload") { + if user == "" && (share.Options.Exposure != "both" && share.Options.Exposure != "upload") { writeError(w, http.StatusUnauthorized, "unauthorized") return } @@ -184,8 +185,9 @@ func (h *Hupload) deleteItem(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusInternalServerError, err.Error()) return } + user, _ := auth.AuthForRequest(r) - if auth.UserForRequest(r) == "" && (share.Options.Exposure != "both" && share.Options.Exposure != "upload") { + if user == "" && (share.Options.Exposure != "both" && share.Options.Exposure != "upload") { writeError(w, http.StatusUnauthorized, "unauthorized") return } @@ -216,8 +218,9 @@ func (h *Hupload) getShares(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusInternalServerError, err.Error()) return } + user, _ := auth.AuthForRequest(r) - if auth.UserForRequest(r) == "" { + if user == "" { writeSuccessJSON(w, storage.PublicShares(shares)) } else { writeSuccessJSON(w, shares) @@ -241,13 +244,14 @@ func (h *Hupload) getShare(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusInternalServerError, err.Error()) return } + user, _ := auth.AuthForRequest(r) - if auth.UserForRequest(r) == "" && !share.IsValid() { + if user == "" && !share.IsValid() { writeError(w, http.StatusGone, "Share expired") return } - if auth.UserForRequest(r) == "" { + if user == "" { writeSuccessJSON(w, share.PublicShare()) } else { writeSuccessJSON(w, share) @@ -270,8 +274,9 @@ func (h *Hupload) getShareItems(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusInternalServerError, err.Error()) return } + user, _ := auth.AuthForRequest(r) - if auth.UserForRequest(r) == "" && !share.IsValid() { + if user == "" && !share.IsValid() { writeError(w, http.StatusGone, "Share expired") return } @@ -324,7 +329,9 @@ func (h *Hupload) getItem(w http.ResponseWriter, r *http.Request) { return } - if auth.UserForRequest(r) == "" && (share.Options.Exposure != "both" && share.Options.Exposure != "download") { + user, _ := auth.AuthForRequest(r) + + if user == "" && (share.Options.Exposure != "both" && share.Options.Exposure != "download") { writeError(w, http.StatusUnauthorized, "unauthorized") return } @@ -366,10 +373,14 @@ func (h *Hupload) getItem(w http.ResponseWriter, r *http.Request) { // postLogin returns the user name for the current session func (h *Hupload) postLogin(w http.ResponseWriter, r *http.Request) { + user, _ := auth.AuthForRequest(r) + u := struct { - User string `json:"user"` + User string `json:"user"` + LoginPage string `json:"loginPage"` }{ - User: auth.UserForRequest(r), + User: user, + LoginPage: h.Config.Authentication.LoginURL(), } writeSuccessJSON(w, u) } diff --git a/hupload/handlers_test.go b/hupload/handlers_test.go index a601e43..332c554 100644 --- a/hupload/handlers_test.go +++ b/hupload/handlers_test.go @@ -67,7 +67,7 @@ func getHupload(t *testing.T, cfg *config.Config) *Hupload { func makeShare(t *testing.T, h *Hupload, name string, params storage.Options) *storage.Share { share, err := h.Config.Storage.CreateShare(context.Background(), name, "admin", params) if err != nil { - t.Fatal(err) + t.Error(err) } return share } @@ -208,7 +208,9 @@ func TestCreateShare(t *testing.T) { t.Run("Create a share with same name should fail", func(t *testing.T) { makeShare(t, h, "samename", storage.Options{}) - + t.Cleanup(func() { + _ = h.Config.Storage.DeleteShare(context.Background(), "samename") + }) var ( req *http.Request w *httptest.ResponseRecorder @@ -228,10 +230,6 @@ func TestCreateShare(t *testing.T) { t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) return } - - t.Cleanup(func() { - _ = h.Config.Storage.DeleteShare(context.Background(), "samename") - }) }) t.Run("Create a share with specific name should succeed", func(t *testing.T) { diff --git a/hupload/internal/config/config.go b/hupload/internal/config/config.go index 21dee7a..7a0443d 100644 --- a/hupload/internal/config/config.go +++ b/hupload/internal/config/config.go @@ -201,6 +201,19 @@ func (c *Config) authentication() (authentication.Authentication, error) { return nil, err } return authentication.NewAuthenticationFile(options) + case "oidc": + var options authentication.AuthenticationOIDCConfig + + b, err := yaml.Marshal(a.Options) + if err != nil { + return nil, err + } + + err = yaml.Unmarshal(b, &options) + if err != nil { + return nil, err + } + return authentication.NewAuthenticationOIDC(options) case "default": return authentication.NewAuthenticationDefault(), nil } diff --git a/hupload/pkg/apiws/apiws.go b/hupload/pkg/apiws/apiws.go index 6938196..a5128ac 100644 --- a/hupload/pkg/apiws/apiws.go +++ b/hupload/pkg/apiws/apiws.go @@ -6,11 +6,13 @@ import ( "io/fs" "log/slog" "net/http" + "os" "path" "github.com/ybizeul/hupload/pkg/apiws/authentication" "github.com/ybizeul/hupload/pkg/apiws/middleware/auth" logger "github.com/ybizeul/hupload/pkg/apiws/middleware/log" + "gopkg.in/square/go-jose.v2/json" ) type APIWS struct { @@ -81,6 +83,17 @@ func New(staticUI fs.FS, templateData any) (*APIWS, error) { }) } + result.AddPublicRoute("GET /auth", nil, func(w http.ResponseWriter, r *http.Request) { + response := struct { + ShowLoginForm bool `json:"showLoginForm"` + LoginURL string `json:"loginUrl"` + }{ + ShowLoginForm: result.Authentication.ShowLoginForm(), + LoginURL: result.Authentication.LoginURL(), + } + _ = json.NewEncoder(w).Encode(response) + }) + return result, nil } @@ -92,18 +105,35 @@ func (a *APIWS) SetAuthentication(b authentication.Authentication) { // AddRoute adds a new route to the API Web Server. pattern is the URL pattern // to match. authenticators is a list of Authenticator to use to authenticate // the request. handlerFunc is the function to call when the route is matched. -func (a *APIWS) AddRoute(pattern string, authenticators []auth.AuthMiddleware, handlerFunc func(w http.ResponseWriter, r *http.Request)) { - if authenticators == nil { - a.Mux.HandleFunc(pattern, handlerFunc) +func (a *APIWS) AddRoute(pattern string, authenticator auth.AuthMiddleware, handlerFunc func(w http.ResponseWriter, r *http.Request)) { + j := auth.JWTAuthMiddleware{ + HMACSecret: os.Getenv("JWT_SECRET"), + } + c := auth.ConfirmAuthenticator{Realm: "Hupload"} + a.Mux.Handle(pattern, + authenticator.Middleware( + j.Middleware( + c.Middleware(http.HandlerFunc(handlerFunc))))) +} + +func (a *APIWS) AddPublicRoute(pattern string, authenticator auth.AuthMiddleware, handlerFunc func(w http.ResponseWriter, r *http.Request)) { + j := auth.JWTAuthMiddleware{ + HMACSecret: os.Getenv("JWT_SECRET"), + } + c := auth.ConfirmAuthenticator{Realm: "Hupload"} + o := auth.OpenAuthMiddleware{} + + if authenticator == nil { + a.Mux.Handle(pattern, + j.Middleware( + o.Middleware( + c.Middleware(http.HandlerFunc(handlerFunc))))) } else { - var h http.Handler - h = http.HandlerFunc(handlerFunc) - c := auth.ConfirmAuthenticator{Realm: "Hupload"} - h = c.Middleware(h) - for i := range authenticators { - h = authenticators[len(authenticators)-1-i].Middleware(h) - } - a.Mux.Handle(pattern, h) + a.Mux.Handle(pattern, + authenticator.Middleware( + j.Middleware( + o.Middleware( + c.Middleware(http.HandlerFunc(handlerFunc)))))) } } @@ -111,6 +141,17 @@ func (a *APIWS) AddRoute(pattern string, authenticators []auth.AuthMiddleware, h func (a *APIWS) Start() { slog.Info(fmt.Sprintf("Starting web service on port %d", a.HTTPPort)) + // Check if we have a callback function for this authentication + if _, ok := a.Authentication.CallbackFunc(nil); ok { + // If there is, define action to redirect to "/shares" + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/shares", http.StatusFound) + }) + m := auth.NewJWTAuthMiddleware(os.Getenv("JWT_SECRET")) + f, _ := a.Authentication.CallbackFunc(m.Middleware(handler)) + a.Mux.HandleFunc("GET /oidc", f) + } + err := http.ListenAndServe(fmt.Sprintf(":%d", a.HTTPPort), logger.NewLogger(a.Mux)) if err != nil { slog.Error("unable to start http server", slog.String("error", err.Error())) diff --git a/hupload/pkg/apiws/apiws_test.go b/hupload/pkg/apiws/apiws_test.go index 0ff7eaf..3ba4da1 100644 --- a/hupload/pkg/apiws/apiws_test.go +++ b/hupload/pkg/apiws/apiws_test.go @@ -2,6 +2,7 @@ package apiws import ( "encoding/json" + "errors" "io/fs" "log/slog" "net/http" @@ -44,7 +45,7 @@ func makeAPI(staticUI fs.FS, templateData any) *APIWS { func TestSimpleAPI(t *testing.T) { api := makeAPI(nil, nil) - api.AddRoute("GET /", nil, func(w http.ResponseWriter, r *http.Request) { + api.AddPublicRoute("GET /", nil, func(w http.ResponseWriter, r *http.Request) { writeSuccessJSON(w, map[string]string{"status": "ok"}) }) @@ -62,8 +63,25 @@ type testAuth struct { Password string } -func (a *testAuth) AuthenticateUser(username, password string) (bool, error) { - return a.Username == username && a.Password == password, nil +func (a *testAuth) AuthenticateRequest(w http.ResponseWriter, r *http.Request) error { + username, password, ok := r.BasicAuth() + if !ok { + return errors.New("No basic auth") + } + if a.Username == username && a.Password == password { + return nil + } + return errors.New("bad credentials") +} + +func (o *testAuth) CallbackFunc(http.Handler) (func(w http.ResponseWriter, r *http.Request), bool) { + return nil, false +} +func (o *testAuth) ShowLoginForm() bool { + return false +} +func (o *testAuth) LoginURL() string { + return "/" } func TestAuthAPI(t *testing.T) { @@ -76,7 +94,7 @@ func TestAuthAPI(t *testing.T) { api := makeAPI(nil, nil) - api.AddRoute("GET /", []auth.AuthMiddleware{authenticator}, func(w http.ResponseWriter, r *http.Request) { + api.AddRoute("GET /", authenticator, func(w http.ResponseWriter, r *http.Request) { writeSuccessJSON(w, map[string]string{"status": "ok"}) }) diff --git a/hupload/pkg/apiws/authentication/authentication.go b/hupload/pkg/apiws/authentication/authentication.go index 254fa3e..1fb5ee9 100644 --- a/hupload/pkg/apiws/authentication/authentication.go +++ b/hupload/pkg/apiws/authentication/authentication.go @@ -1,11 +1,28 @@ package authentication +import ( + "net/http" +) + type User struct { Username string `yaml:"username"` Password string `yaml:"password"` } +type AuthStatusKeyType string + +var AuthStatusKey AuthStatusKeyType = "AuthStatus" + +type AuthStatus struct { + Authenticated bool + User string + Error error +} + // AuthenticationInterface must be implemented by the authentication backend type Authentication interface { - AuthenticateUser(username, password string) (bool, error) + AuthenticateRequest(w http.ResponseWriter, r *http.Request) error + CallbackFunc(http.Handler) (func(w http.ResponseWriter, r *http.Request), bool) + ShowLoginForm() bool + LoginURL() string } diff --git a/hupload/pkg/apiws/authentication/default.go b/hupload/pkg/apiws/authentication/default.go index 3c527df..a5c9425 100644 --- a/hupload/pkg/apiws/authentication/default.go +++ b/hupload/pkg/apiws/authentication/default.go @@ -4,6 +4,7 @@ import ( "fmt" "log/slog" "math/rand/v2" + "net/http" ) // AuthenticationDefault is the default authentication when none has been found @@ -22,13 +23,28 @@ func NewAuthenticationDefault() *AuthenticationDefault { return r } -func (a *AuthenticationDefault) AuthenticateUser(username, password string) (bool, error) { +func (a *AuthenticationDefault) AuthenticateRequest(w http.ResponseWriter, r *http.Request) error { + username, password, ok := r.BasicAuth() + if !ok { + return ErrAuthenticationMissingCredentials + } if username == "admin" && password == a.Password { - return true, nil + + return nil } - return false, nil + return ErrAuthenticationBadCredentials +} + +func (o *AuthenticationDefault) CallbackFunc(http.Handler) (func(w http.ResponseWriter, r *http.Request), bool) { + return nil, false } +func (o *AuthenticationDefault) ShowLoginForm() bool { + return false +} +func (o *AuthenticationDefault) LoginURL() string { + return "/" +} func generateCode(l int) string { code := "" diff --git a/hupload/pkg/apiws/authentication/default_test.go b/hupload/pkg/apiws/authentication/default_test.go index 29b59e4..614397c 100644 --- a/hupload/pkg/apiws/authentication/default_test.go +++ b/hupload/pkg/apiws/authentication/default_test.go @@ -1,6 +1,7 @@ package authentication import ( + "net/http" "regexp" "testing" ) @@ -15,11 +16,10 @@ func TestDefaultAuthentication(t *testing.T) { t.Errorf("Expected password to be 7 characters long, got %s", p) } - b, err := a.AuthenticateUser("admin", p) + r, _ := http.NewRequest("GET", "http://localhost:8080", nil) + r.SetBasicAuth("admin", p) - if !b { - t.Errorf("Expected true, got false") - } + err := a.AuthenticateRequest(nil, r) if err != nil { t.Errorf("Expected no error, got %v", err) diff --git a/hupload/pkg/apiws/authentication/errors.go b/hupload/pkg/apiws/authentication/errors.go index a7777ec..7f3b020 100644 --- a/hupload/pkg/apiws/authentication/errors.go +++ b/hupload/pkg/apiws/authentication/errors.go @@ -4,3 +4,8 @@ import "errors" var ErrAuthenticationMissingUsersFile = errors.New("missing users file") var ErrAuthenticationInvalidPath = errors.New("invalid path") + +var ErrAuthenticationMissingCredentials = errors.New("no credentials provided in request") +var ErrAuthenticationBadCredentials = errors.New("invalid username or password") + +var ErrAuthenticationRedirect = errors.New("redirect to authenticate") diff --git a/hupload/pkg/apiws/authentication/file.go b/hupload/pkg/apiws/authentication/file.go index 5187300..c205fbc 100644 --- a/hupload/pkg/apiws/authentication/file.go +++ b/hupload/pkg/apiws/authentication/file.go @@ -1,6 +1,7 @@ package authentication import ( + "net/http" "os" "golang.org/x/crypto/bcrypt" @@ -39,7 +40,12 @@ func NewAuthenticationFile(o FileAuthenticationConfig) (*AuthenticationFile, err return &r, nil } -func (a *AuthenticationFile) AuthenticateUser(username, password string) (bool, error) { +func (a *AuthenticationFile) AuthenticateRequest(w http.ResponseWriter, r *http.Request) error { + username, password, ok := r.BasicAuth() + if !ok { + return ErrAuthenticationMissingCredentials + } + // Prepare struct to load users.yaml var users []User @@ -48,14 +54,14 @@ func (a *AuthenticationFile) AuthenticateUser(username, password string) (bool, // Fail if we can't open the file pf, err := os.Open(path) if err != nil { - return false, err + return err } defer pf.Close() // Load users.yml err = yaml.NewDecoder(pf).Decode(&users) if err != nil { - return false, err + return err } // Check if user is in the list @@ -64,9 +70,22 @@ func (a *AuthenticationFile) AuthenticateUser(username, password string) (bool, // Compare password hash err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) if err == nil { - return true, nil + return nil } } } - return false, nil + + return ErrAuthenticationBadCredentials +} + +func (o *AuthenticationFile) CallbackFunc(http.Handler) (func(w http.ResponseWriter, r *http.Request), bool) { + return nil, false +} + +func (o *AuthenticationFile) ShowLoginForm() bool { + return true +} + +func (o *AuthenticationFile) LoginURL() string { + return "/" } diff --git a/hupload/pkg/apiws/authentication/file_test.go b/hupload/pkg/apiws/authentication/file_test.go index 19f1b57..e65baae 100644 --- a/hupload/pkg/apiws/authentication/file_test.go +++ b/hupload/pkg/apiws/authentication/file_test.go @@ -2,6 +2,7 @@ package authentication import ( "errors" + "net/http" "testing" ) @@ -16,29 +17,31 @@ func TestAuthentication(t *testing.T) { t.Errorf("Expected no error, got %v", err) } - b, err := a.AuthenticateUser("admin", "hupload") - if !b { - t.Errorf("Expected true, got false") - } + r, _ := http.NewRequest("GET", "http://localhost:8080", nil) + r.SetBasicAuth("admin", "hupload") - if err != nil { - t.Errorf("Expected no error, got %v", err) - } + err = a.AuthenticateRequest(nil, r) - b, err = a.AuthenticateUser("admin", "random") - if b { - t.Errorf("Expected false, got true") - } if err != nil { t.Errorf("Expected no error, got %v", err) } - b, err = a.AuthenticateUser("nonexistant", "random") - if b { - t.Errorf("Expected false, got true") + r, _ = http.NewRequest("GET", "http://localhost:8080", nil) + r.SetBasicAuth("admin", "random") + + err = a.AuthenticateRequest(nil, r) + + if err != ErrAuthenticationBadCredentials { + t.Errorf("Expected ErrAuthenticationBadCredentials, got %v", err) } - if err != nil { - t.Errorf("Expected no error, got %v", err) + + r, _ = http.NewRequest("GET", "http://localhost:8080", nil) + r.SetBasicAuth("nonexistant", "random") + + err = a.AuthenticateRequest(nil, r) + + if err != ErrAuthenticationBadCredentials { + t.Errorf("Expected ErrAuthenticationBadCredentials, got %v", err) } } diff --git a/hupload/pkg/apiws/authentication/oidc.go b/hupload/pkg/apiws/authentication/oidc.go new file mode 100644 index 0000000..f122e53 --- /dev/null +++ b/hupload/pkg/apiws/authentication/oidc.go @@ -0,0 +1,129 @@ +package authentication + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/coreos/go-oidc" + "golang.org/x/oauth2" +) + +type AuthenticationOIDCConfig struct { + ProviderURL string `yaml:"provider_url"` + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + RedirectURL string `yaml:"redirect_url"` +} + +type AuthenticationOIDC struct { + Options AuthenticationOIDCConfig + + Provider *oidc.Provider + Config oauth2.Config +} + +func NewAuthenticationOIDC(o AuthenticationOIDCConfig) (*AuthenticationOIDC, error) { + var err error + result := &AuthenticationOIDC{ + Options: o, + } + + result.Provider, err = oidc.NewProvider(context.Background(), result.Options.ProviderURL) + if err != nil { + return nil, err + } + + // Configure an OpenID Connect aware OAuth2 client. + result.Config = oauth2.Config{ + ClientID: result.Options.ClientID, + ClientSecret: result.Options.ClientSecret, + RedirectURL: result.Options.RedirectURL, + + // Discovery returns the OAuth2 endpoints. + Endpoint: result.Provider.Endpoint(), + + // "openid" is a required scope for OpenID Connect flows. + Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access", "preferred_username"}, + } + + return result, nil +} + +func (o *AuthenticationOIDC) AuthenticateRequest(w http.ResponseWriter, r *http.Request) error { + if r.URL.Path == "/login" { + http.Redirect(w, r, o.Config.AuthCodeURL("state"), http.StatusFound) + return ErrAuthenticationRedirect + } + return nil +} + +func (o *AuthenticationOIDC) CallbackFunc(h http.Handler) (func(w http.ResponseWriter, r *http.Request), bool) { + return func(w http.ResponseWriter, r *http.Request) { + var verifier = o.Provider.Verifier(&oidc.Config{ClientID: o.Options.ClientID}) + + // Verify state and errors. + + oauth2Token, err := o.Config.Exchange(r.Context(), r.URL.Query().Get("code")) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(err) + return + } + + // Extract the ID Token from OAuth2 token. + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(err) + return + } + + // Parse and verify ID Token payload. + idToken, err := verifier.Verify(r.Context(), rawIDToken) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(err) + return + } + + // Extract custom claims + var claims struct { + Sub string `json:"sub"` + Email string `json:"email"` + Verified bool `json:"email_verified"` + } + + if err := idToken.Claims(&claims); err != nil { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(err) + return + } + ServeNextAuthenticated(claims.Sub, h, w, r) + }, true +} + +func ServeNextAuthenticated(user string, next http.Handler, w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), AuthStatusKey, AuthStatus{Authenticated: true, User: user}) + next.ServeHTTP(w, r.WithContext(ctx)) + // if user == "" { + // next.ServeHTTP(w, + // r.WithContext( + // context.WithValue( + // r.Context(), + // AuthStatusKey,AuthStatus{Authenticated: true, User: ""}, + // ), + // ), + // ) + // } else { + // ctx := context.WithValue(r.Context(), AuthStatus{Authenticated: true, User: user}) + // next.ServeHTTP(w, r.WithContext(ctx)) + // } +} + +func (o *AuthenticationOIDC) ShowLoginForm() bool { + return false +} +func (o *AuthenticationOIDC) LoginURL() string { + return "/login" +} diff --git a/hupload/pkg/apiws/middleware/auth/auth.go b/hupload/pkg/apiws/middleware/auth/auth.go index ce7d72b..1d0b850 100644 --- a/hupload/pkg/apiws/middleware/auth/auth.go +++ b/hupload/pkg/apiws/middleware/auth/auth.go @@ -8,65 +8,70 @@ import ( "log/slog" "net/http" "strings" + + "github.com/ybizeul/hupload/pkg/apiws/authentication" ) type AuthMiddleware interface { Middleware(http.Handler) http.Handler } -type ContextValue string - -const ( - AuthStatus ContextValue = "Authenticated" - AuthUser ContextValue = "User" - AuthError ContextValue = "Error" - AuthStatusSuccess = "Success" -) - // serveNextAuthenticated adds a passes w and r to next middleware after adding // successful authentication context key/value -func serveNextAuthenticated(user string, next http.Handler, w http.ResponseWriter, r *http.Request) { - if user == "" { - next.ServeHTTP(w, - r.WithContext( - context.WithValue( - r.Context(), - AuthStatus, - AuthStatusSuccess, - ), - ), - ) - } else { - next.ServeHTTP(w, - r.WithContext( - context.WithValue( - context.WithValue( - r.Context(), - AuthStatus, - AuthStatusSuccess), - AuthUser, - user, - ), - ), - ) +func ServeNextAuthenticated(user string, next http.Handler, w http.ResponseWriter, r *http.Request) { + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if !ok { + s = authentication.AuthStatus{} + } + if user != "" { + s.User = user } + s.Authenticated = true + ctx := context.WithValue(r.Context(), authentication.AuthStatusKey, s) + next.ServeHTTP(w, r.WithContext(ctx)) + + // if user == "" { + // next.ServeHTTP(w, + // r.WithContext( + // context.WithValue( + // r.Context(), + // AuthStatus, + // AuthStatusSuccess, + // ), + // ), + // ) + // } else { + // next.ServeHTTP(w, + // r.WithContext( + // context.WithValue( + // context.WithValue( + // r.Context(), + // AuthStatus, + // AuthStatusSuccess), + // AuthUser, + // user, + // ), + // ), + // ) + // } } // serveNextError adds a passes w and r to next middleware after adding // failed authentication context key/value // any previously defined err is wrapped around err -func serveNextError(next http.Handler, w http.ResponseWriter, r *http.Request, err error) { +func ServeNextError(next http.Handler, w http.ResponseWriter, r *http.Request, err error) { if err == nil { err = errors.New("unknown error") } - - e, ok := r.Context().Value(AuthError).(error) + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) var c context.Context if ok { - c = context.WithValue(r.Context(), AuthError, errors.Join(err, e)) + s.Error = errors.Join(s.Error, err) + s.Authenticated = false } else { - c = context.WithValue(r.Context(), AuthError, err) + s = authentication.AuthStatus{Error: err} } + c = context.WithValue(r.Context(), authentication.AuthStatusKey, s) next.ServeHTTP(w, r.WithContext(c)) } @@ -76,18 +81,22 @@ type ConfirmAuthenticator struct { func (a *ConfirmAuthenticator) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Context().Value(AuthStatus) == AuthStatusSuccess { + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if ok && s.Authenticated { + // if r.URL.Path == "/oidc" { + // http.Redirect(w, r, "/shares", http.StatusFound) + // return + // } next.ServeHTTP(w, r) return } w.Header().Add("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"%s\"", a.Realm)) w.WriteHeader(http.StatusUnauthorized) - e, ok := r.Context().Value(AuthError).(error) - if ok { + if s.Error != nil { errs := struct { Errors []string `json:"errors"` }{ - Errors: strings.Split(e.Error(), "\n"), + Errors: strings.Split(s.Error.Error(), "\n"), } b, _ := json.Marshal(errs) @@ -103,10 +112,10 @@ func (a *ConfirmAuthenticator) Middleware(next http.Handler) http.Handler { }) } -func UserForRequest(r *http.Request) string { - user, ok := r.Context().Value(AuthUser).(string) +func AuthForRequest(r *http.Request) (string, bool) { + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) if !ok { - return "" + return "", false } - return user + return s.User, s.Authenticated } diff --git a/hupload/pkg/apiws/middleware/auth/auth_test.go b/hupload/pkg/apiws/middleware/auth/auth_test.go index a3c0583..d488a59 100644 --- a/hupload/pkg/apiws/middleware/auth/auth_test.go +++ b/hupload/pkg/apiws/middleware/auth/auth_test.go @@ -6,27 +6,27 @@ import ( "net/http" "net/http/httptest" "testing" -) -var fakeError = errors.New("Some Error") + "github.com/ybizeul/hupload/pkg/apiws/authentication" +) func TestServeNextAuthenticated(t *testing.T) { successMiddleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - serveNextAuthenticated("user", next, w, r) + ServeNextAuthenticated("user", next, w, r) }) } fn1 := func(w http.ResponseWriter, r *http.Request) { - c := r.Context().Value(AuthError) + s := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + c := s.Error if c != nil { t.Errorf("Expected nil, got %v", c.(error)) } - c = r.Context().Value(AuthStatus) - if c != AuthStatusSuccess { + if !s.Authenticated { t.Errorf("Expected AuthStatusSuccess, got %v", c) } - u := r.Context().Value(AuthUser) + u := s.User if u != "user" { t.Errorf("Expected admin, got %v", u) } @@ -39,20 +39,22 @@ func TestServeNextAuthenticated(t *testing.T) { h1.ServeHTTP(nil, req) } +var fakeError = errors.New("Some Error") + func TestServeNextAuthFailed(t *testing.T) { successMiddleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - serveNextError(next, w, r, fakeError) + ServeNextError(next, w, r, fakeError) }) } fn1 := func(w http.ResponseWriter, r *http.Request) { - c := r.Context().Value(AuthError) - if c == nil { + s := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if s.Error == nil { t.Errorf("Expected error, got nil") } else { - if !errors.Is(c.(error), fakeError) { - t.Errorf("Expected fakeError, got %v", c) + if !errors.Is(s.Error, fakeError) { + t.Errorf("Expected fakeError, got %v", s.Error) } } } @@ -67,7 +69,7 @@ func TestServeNextAuthFailed(t *testing.T) { func TestConfirmAuthentication(t *testing.T) { successMiddleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - serveNextAuthenticated("user", next, w, r) + ServeNextAuthenticated("user", next, w, r) }) } @@ -93,7 +95,7 @@ func TestConfirmAuthentication(t *testing.T) { func TestFailedAuthentication(t *testing.T) { successMiddleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - serveNextError(next, w, r, fakeError) + ServeNextError(next, w, r, fakeError) }) } diff --git a/hupload/pkg/apiws/middleware/auth/basic.go b/hupload/pkg/apiws/middleware/auth/basic.go index abc070a..f5a15f4 100644 --- a/hupload/pkg/apiws/middleware/auth/basic.go +++ b/hupload/pkg/apiws/middleware/auth/basic.go @@ -29,29 +29,34 @@ type BasicAuthMiddleware struct { func (a BasicAuthMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if a.Authentication == nil { - serveNextError(next, w, r, errors.New("no authentication backend")) + ServeNextError(next, w, r, errors.New("no authentication backend")) return } - // Collect authentication from request - qUser, qPasswd, ok := r.BasicAuth() + + // If there is no credentials, skip middleware + var qUser string + var ok bool + if qUser, _, ok = r.BasicAuth(); !ok { + next.ServeHTTP(w, r) + } // If authentication has been sent, check credentials - if ok { - b, err := a.Authentication.AuthenticateUser(qUser, qPasswd) - if err != nil { - serveNextError(next, w, r, err) + + err := a.Authentication.AuthenticateRequest(nil, r) + + if err != nil { + if errors.Is(err, authentication.ErrAuthenticationMissingCredentials) { + ServeNextAuthenticated("", next, w, r) return } - if !b { - serveNextError(next, w, r, ErrBasicAuthAuthenticationFailed) - return - } else { - serveNextAuthenticated(qUser, next, w, r) + if errors.Is(err, authentication.ErrAuthenticationBadCredentials) { + ServeNextError(next, w, r, ErrBasicAuthAuthenticationFailed) return } + ServeNextError(next, w, r, err) + return } - // Fail by default - serveNextError(next, w, r, ErrBasicAuthNoCredentials) + ServeNextAuthenticated(qUser, next, w, r) }) } diff --git a/hupload/pkg/apiws/middleware/auth/basic_test.go b/hupload/pkg/apiws/middleware/auth/basic_test.go index b1cb126..abeb87f 100644 --- a/hupload/pkg/apiws/middleware/auth/basic_test.go +++ b/hupload/pkg/apiws/middleware/auth/basic_test.go @@ -24,17 +24,18 @@ func TestBasicAuth(t *testing.T) { } fn1 := func(w http.ResponseWriter, r *http.Request) { - c := r.Context().Value(AuthError) - if c != nil { - t.Errorf("Expected nil, got %v", c.(error)) + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if !ok { + t.Errorf("Expected AuthStatus, got nil") } - c = r.Context().Value(AuthStatus) - if c != AuthStatusSuccess { - t.Errorf("Expected AuthStatusSuccess, got %v", c) + if s.Error != nil { + t.Errorf("Expected nil, got %v", s.Error) } - u := r.Context().Value(AuthUser) - if u != "admin" { - t.Errorf("Expected admin, got %v", u) + if s.Authenticated == false { + t.Errorf("Expected Success, got %t", s.Authenticated) + } + if s.User != "admin" { + t.Errorf("Expected admin, got %v", s.User) } } @@ -62,11 +63,14 @@ func TestBasicWrongCredentials(t *testing.T) { } fn1 := func(w http.ResponseWriter, r *http.Request) { - c := r.Context().Value(AuthError) - if c == nil { + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if !ok { + t.Errorf("Expected AuthStatus, got nil") + } + if s.Error == nil { t.Errorf("Expected error, got nil") } else { - if !errors.Is(c.(error), ErrBasicAuthAuthenticationFailed) { + if !errors.Is(s.Error, ErrBasicAuthAuthenticationFailed) { t.Errorf("Expected authentication failed, got %v", c) } } @@ -96,8 +100,11 @@ func TestBasicAuthNoCredentials(t *testing.T) { } fn1 := func(w http.ResponseWriter, r *http.Request) { - c := r.Context().Value(AuthError).(error) - if !errors.Is(c, ErrBasicAuthNoCredentials) { + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if !ok { + t.Errorf("Expected AuthStatus, got nil") + } + if !errors.Is(s.Error, ErrBasicAuthNoCredentials) { t.Errorf("Expected ErrBasicAuthNoCredentials, got %v", c) } } diff --git a/hupload/pkg/apiws/middleware/auth/jwt.go b/hupload/pkg/apiws/middleware/auth/jwt.go index 83add1f..ef41d32 100644 --- a/hupload/pkg/apiws/middleware/auth/jwt.go +++ b/hupload/pkg/apiws/middleware/auth/jwt.go @@ -38,23 +38,22 @@ func (j JWTAuthMiddleware) Middleware(next http.Handler) http.Handler { shortCookie, _ := r.Cookie("X-Token") longCookie, _ := r.Cookie("X-Token-Refresh") - upstreamUser := UserForRequest(r) - upstreamAuth := r.Context().Value(AuthStatus) == AuthStatusSuccess + upstreamUser, ok := AuthForRequest(r) - if (shortCookie == nil && longCookie == nil) || (upstreamUser != "" && upstreamAuth) { + if (shortCookie == nil && longCookie == nil) || (upstreamUser != "" && ok) { // Check that authentication has been previoulsy approved // If request is already authenticated, generate a JWT token - if upstreamAuth { + if ok { short, long, err := j.generateTokens(upstreamUser) if err != nil { - serveNextError(next, w, r, err) + ServeNextError(next, w, r, err) return } http.SetCookie(w, &http.Cookie{Name: "X-Token", Value: short, Path: "/", Expires: time.Now().Add(shortTokenMinutesExpire)}) http.SetCookie(w, &http.Cookie{Name: "X-Token-Refresh", Value: long, Path: "/", Expires: time.Now().Add(longTokenMinutesExpire)}) - serveNextAuthenticated(upstreamUser, next, w, r) + ServeNextAuthenticated(upstreamUser, next, w, r) return } @@ -62,7 +61,7 @@ func (j JWTAuthMiddleware) Middleware(next http.Handler) http.Handler { http.SetCookie(w, &http.Cookie{Name: "X-Token", Value: "deleted", Path: "/", Expires: time.Unix(0, 0)}) http.SetCookie(w, &http.Cookie{Name: "X-Token-Refresh", Value: "deleted", Path: "/", Expires: time.Unix(0, 0)}) - serveNextError(next, w, r, JWTAuthNoAuthorizationHeader) + ServeNextError(next, w, r, JWTAuthNoAuthorizationHeader) return } @@ -87,39 +86,36 @@ func (j JWTAuthMiddleware) Middleware(next http.Handler) http.Handler { if err != nil { http.SetCookie(w, &http.Cookie{Name: "X-Token", Value: "deleted", Path: "/", Expires: time.Unix(0, 0)}) http.SetCookie(w, &http.Cookie{Name: "X-Token-Refresh", Value: "deleted", Path: "/", Expires: time.Unix(0, 0)}) - serveNextError(next, w, r, fmt.Errorf("Unable to parse token: %w", err)) + ServeNextError(next, w, r, fmt.Errorf("Unable to parse token: %w", err)) return } if !token.Valid { http.SetCookie(w, &http.Cookie{Name: "X-Token", Value: "deleted", Path: "/", Expires: time.Unix(0, 0)}) http.SetCookie(w, &http.Cookie{Name: "X-Token-Refresh", Value: "deleted", Path: "/", Expires: time.Unix(0, 0)}) - serveNextError(next, w, r, errors.New("Invalid token")) + ServeNextError(next, w, r, errors.New("Invalid token")) return } if claims, ok := token.Claims.(jwt.MapClaims); ok { user, ok := claims["sub"].(string) if !ok || user == "" { - serveNextError(next, w, r, JWTAuthNoSubClaim) + ServeNextError(next, w, r, JWTAuthNoSubClaim) return } _, ok = claims["refresh"] if ok { if !ok { - serveNextError(next, w, r, JWTAuthNoSubClaim) + ServeNextError(next, w, r, JWTAuthNoSubClaim) } short, long, err := j.generateTokens(user) if err != nil { - serveNextError(next, w, r, err) + ServeNextError(next, w, r, err) } http.SetCookie(w, &http.Cookie{Name: "X-Token", Value: short, Path: "/", Expires: time.Now().Add(shortTokenMinutesExpire)}) http.SetCookie(w, &http.Cookie{Name: "X-Token-Refresh", Value: long, Path: "/", Expires: time.Now().Add(longTokenMinutesExpire)}) } - serveNextAuthenticated(user, next, w, r) - - // TODO Verify claim content - //fmt.Println(claims["iss"], claims["sub"], claims["exp"]) + ServeNextAuthenticated(user, next, w, r) } else { slog.Error("jwt decoding returned an invalid claim") } diff --git a/hupload/pkg/apiws/middleware/auth/jwt_test.go b/hupload/pkg/apiws/middleware/auth/jwt_test.go index c8ed0b8..b5f314d 100644 --- a/hupload/pkg/apiws/middleware/auth/jwt_test.go +++ b/hupload/pkg/apiws/middleware/auth/jwt_test.go @@ -5,6 +5,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/ybizeul/hupload/pkg/apiws/authentication" ) func TestJWTAuth(t *testing.T) { @@ -22,17 +24,18 @@ func TestJWTAuth(t *testing.T) { } fn1 := func(w http.ResponseWriter, r *http.Request) { - c := r.Context().Value(AuthError) - if c != nil { - t.Errorf("Expected nil, got %v", c.(error)) + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if !ok { + t.Errorf("Expected AuthStatus, got nil") + } + if s.Error != nil { + t.Errorf("Expected nil, got %v", s.Error) } - c = r.Context().Value(AuthStatus) - if c != AuthStatusSuccess { - t.Errorf("Expected AuthStatusSuccess, got %v", c) + if s.Authenticated == false { + t.Errorf("Expected AuthStatusSuccess, got %t", s.Authenticated) } - u := r.Context().Value(AuthUser) - if u != "admin" { - t.Errorf("Expected admin, got %v", u) + if s.User != "admin" { + t.Errorf("Expected admin, got %v", s.User) } _, _ = w.Write([]byte("OK")) } @@ -76,8 +79,11 @@ func TestJWTAuthBadSecret(t *testing.T) { } fn1 := func(w http.ResponseWriter, r *http.Request) { - c := r.Context().Value(AuthError) - if c == nil { + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if !ok { + t.Errorf("Expected AuthStatus, got nil") + } + if s.Error == nil { t.Errorf("Expected error, got nil") } w.WriteHeader(http.StatusUnauthorized) diff --git a/hupload/pkg/apiws/middleware/auth/oidc.go b/hupload/pkg/apiws/middleware/auth/oidc.go new file mode 100644 index 0000000..13a753b --- /dev/null +++ b/hupload/pkg/apiws/middleware/auth/oidc.go @@ -0,0 +1,45 @@ +package auth + +import ( + "errors" + "net/http" + + "github.com/ybizeul/hupload/pkg/apiws/authentication" +) + +// BasicAuthenticator uses a password file to authenticate users, like : +// - username: admin +// password: $2y$10$AJEytAoJfc4yQjUS8/cG6eXADlgK/Dt3AvdB0boPJ7EcHofewGQIK +// +// To has a password, you can use htpasswd command : +// +// ❯ htpasswd -bnBC 10 "" hupload +// :$2y$10$AJEytAoJfc4yQjUS8/cG6eXADlgK/Dt3AvdB0boPJ7EcHofewGQIK +// +// and remove the leading `:` from the hash +type OIDCAuthMiddleware struct { + Authentication authentication.Authentication +} + +func (a OIDCAuthMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if a.Authentication == nil { + ServeNextError(next, w, r, errors.New("no authentication backend")) + return + } + + // If authentication has been sent, check credentials + + err := a.Authentication.AuthenticateRequest(w, r) + if err != nil { + if err == authentication.ErrAuthenticationRedirect { + return + } + ServeNextError(next, w, r, err) + return + } + + next.ServeHTTP(w, r) + //ServeNextError(next, w, r, authentication.ErrAuthenticationMissingCredentials) + }) +} diff --git a/hupload/pkg/apiws/middleware/auth/open.go b/hupload/pkg/apiws/middleware/auth/open.go index fb0f8f9..60cf9fc 100644 --- a/hupload/pkg/apiws/middleware/auth/open.go +++ b/hupload/pkg/apiws/middleware/auth/open.go @@ -9,6 +9,6 @@ type OpenAuthMiddleware struct { func (a OpenAuthMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - serveNextAuthenticated("", next, w, r) + ServeNextAuthenticated("", next, w, r) }) } diff --git a/hupload/pkg/apiws/middleware/auth/open_test.go b/hupload/pkg/apiws/middleware/auth/open_test.go index 66f92f3..1017faa 100644 --- a/hupload/pkg/apiws/middleware/auth/open_test.go +++ b/hupload/pkg/apiws/middleware/auth/open_test.go @@ -3,6 +3,8 @@ package auth import ( "net/http" "testing" + + "github.com/ybizeul/hupload/pkg/apiws/authentication" ) func TestOpenAuthWithCredentials(t *testing.T) { @@ -10,17 +12,18 @@ func TestOpenAuthWithCredentials(t *testing.T) { m := OpenAuthMiddleware{} fn1 := func(w http.ResponseWriter, r *http.Request) { - c := r.Context().Value(AuthError) - if c != nil { - t.Errorf("Expected nil, got %v", c.(error)) + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if !ok { + t.Errorf("Expected AuthStatus, got nil") + } + if s.Error != nil { + t.Errorf("Expected nil, got %v", s.Error) } - c = r.Context().Value(AuthStatus) - if c != AuthStatusSuccess { - t.Errorf("Expected AuthStatusSuccess, got %v", c) + if !s.Authenticated { + t.Errorf("Expected AuthStatusSuccess, got %t", s.Authenticated) } - u := r.Context().Value(AuthUser) - if u != nil { - t.Errorf("Expected nil, got %v", u) + if s.User != "" { + t.Errorf("Expected nil, got %s", s.User) } } @@ -37,17 +40,18 @@ func TestOpenAuthWithoutCredentials(t *testing.T) { m := OpenAuthMiddleware{} fn1 := func(w http.ResponseWriter, r *http.Request) { - c := r.Context().Value(AuthError) - if c != nil { - t.Errorf("Expected nil, got %v", c.(error)) + s, ok := r.Context().Value(authentication.AuthStatusKey).(authentication.AuthStatus) + if !ok { + t.Errorf("Expected AuthStatus, got nil") + } + if s.Error != nil { + t.Errorf("Expected nil, got %v", s.Error) } - c = r.Context().Value(AuthStatus) - if c != AuthStatusSuccess { - t.Errorf("Expected AuthStatusSuccess, got %v", c) + if !s.Authenticated { + t.Errorf("Expected AuthStatusSuccess, got %t", s.Authenticated) } - u := r.Context().Value(AuthUser) - if u != nil { - t.Errorf("Expected nil, got %v", u) + if s.User != "" { + t.Errorf("Expected nil, got %v", s.User) } } diff --git a/hupload/server.go b/hupload/server.go index 0eda1c7..badacc5 100644 --- a/hupload/server.go +++ b/hupload/server.go @@ -11,6 +11,7 @@ import ( "github.com/ybizeul/hupload/internal/config" "github.com/ybizeul/hupload/pkg/apiws" + "github.com/ybizeul/hupload/pkg/apiws/authentication" "github.com/ybizeul/hupload/pkg/apiws/middleware/auth" ) @@ -61,54 +62,81 @@ func (h *Hupload) setup() { api := h.API // Get JWT_SECRET - hmac := os.Getenv("JWT_SECRET") - if len(hmac) == 0 { - hmac = generateRandomString(32) - } + // hmac := os.Getenv("JWT_SECRET") + // if len(hmac) == 0 { + // hmac = generateRandomString(32) + // } // Define authenticators for protected routes - authenticators := []auth.AuthMiddleware{ - auth.BasicAuthMiddleware{ + // authenticators := []auth.AuthMiddleware{ + // auth.BasicAuthMiddleware{ + // Authentication: api.Authentication, + // }, + // auth.JWTAuthMiddleware{ + // HMACSecret: hmac, + // }, + // } + var authenticator auth.AuthMiddleware + switch h.Config.Authentication.(type) { + case *authentication.AuthenticationFile: + authenticator = auth.BasicAuthMiddleware{ Authentication: api.Authentication, - }, - auth.JWTAuthMiddleware{ - HMACSecret: hmac, - }, - } - - authenticatorsOpen := []auth.AuthMiddleware{ - auth.OpenAuthMiddleware{}, - auth.BasicAuthMiddleware{ + } + case *authentication.AuthenticationOIDC: + authenticator = auth.OIDCAuthMiddleware{ Authentication: api.Authentication, - }, - auth.JWTAuthMiddleware{ - HMACSecret: hmac, - }, + } } - + // authenticator := auth.OIDCAuthMiddleware{ + // Authentication: api.Authentication, + // } + + // auth.JWTAuthMiddleware{ + // HMACSecret: hmac, + // }, + + // authenticatorsOpen := []auth.AuthMiddleware{ + // auth.OpenAuthMiddleware{}, + // auth.BasicAuthMiddleware{ + // Authentication: api.Authentication, + // }, + // auth.JWTAuthMiddleware{ + // HMACSecret: hmac, + // }, + // } + // authenticatorsOpen := []auth.AuthMiddleware{ + // auth.OpenAuthMiddleware{}, + // auth.OIDCAuthMiddleware{ + // Authentication: api.Authentication, + // }, + // auth.JWTAuthMiddleware{ + // HMACSecret: hmac, + // }, + // } // Setup routes // Guests can access a share and post new files in it // That's Hupload principle, the security is based on the share name // which is usually a random string. - api.AddRoute("POST /api/v1/shares/{share}/items/{item}", authenticatorsOpen, h.postItem) - api.AddRoute("GET /api/v1/shares/{share}/items", authenticatorsOpen, h.getShareItems) - api.AddRoute("GET /api/v1/shares/{share}", authenticatorsOpen, h.getShare) - api.AddRoute("GET /api/v1/shares/{share}/items/{item}", authenticatorsOpen, h.getItem) - api.AddRoute("GET /d/{share}/{item}", authenticatorsOpen, h.getItem) - api.AddRoute("DELETE /api/v1/shares/{share}/items/{item}", authenticatorsOpen, h.deleteItem) + api.AddPublicRoute("POST /api/v1/shares/{share}/items/{item}", authenticator, h.postItem) + api.AddPublicRoute("GET /api/v1/shares/{share}/items", authenticator, h.getShareItems) + api.AddPublicRoute("GET /api/v1/shares/{share}", authenticator, h.getShare) + api.AddPublicRoute("GET /api/v1/shares/{share}/items/{item}", authenticator, h.getItem) + api.AddPublicRoute("GET /d/{share}/{item}", authenticator, h.getItem) + api.AddPublicRoute("DELETE /api/v1/shares/{share}/items/{item}", authenticator, h.deleteItem) // Protected routes - api.AddRoute("POST /api/v1/login", authenticators, h.postLogin) - api.AddRoute("POST /api/v1/shares", authenticators, h.postShare) - api.AddRoute("POST /api/v1/shares/{share}", authenticators, h.postShare) - api.AddRoute("PATCH /api/v1/shares/{share}", authenticators, h.patchShare) - api.AddRoute("DELETE /api/v1/shares/{share}", authenticators, h.deleteShare) - api.AddRoute("GET /api/v1/shares", authenticators, h.getShares) - api.AddRoute("GET /api/v1/version", authenticators, h.getVersion) - - api.AddRoute("GET /api/v1/*", authenticators, func(w http.ResponseWriter, r *http.Request) { + api.AddRoute("GET /login", authenticator, h.postLogin) + api.AddRoute("POST /api/v1/login", authenticator, h.postLogin) + api.AddRoute("POST /api/v1/shares", authenticator, h.postShare) + api.AddRoute("POST /api/v1/shares/{share}", authenticator, h.postShare) + api.AddRoute("PATCH /api/v1/shares/{share}", authenticator, h.patchShare) + api.AddRoute("DELETE /api/v1/shares/{share}", authenticator, h.deleteShare) + api.AddRoute("GET /api/v1/shares", authenticator, h.getShares) + api.AddRoute("GET /api/v1/version", authenticator, h.getVersion) + + api.AddRoute("GET /api/v1/*", authenticator, func(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "Error") })