diff --git a/cmd/datasetArchiver/main.go b/cmd/datasetArchiver/main.go index 0bc1a5b..cade84c 100644 --- a/cmd/datasetArchiver/main.go +++ b/cmd/datasetArchiver/main.go @@ -101,7 +101,8 @@ func main() { inputdatasetList = args[0:] } - user, _ := datasetUtils.Authenticate(client, APIServer, token, userpass) + auth := &datasetUtils.RealAuthenticator{} + user, _ := datasetUtils.Authenticate(auth, client, APIServer, token, userpass) archivableDatasets := datasetUtils.GetArchivableDatasets(client, APIServer, ownerGroup, inputdatasetList, user["accessToken"]) if len(archivableDatasets) > 0 { diff --git a/cmd/datasetCleaner/main.go b/cmd/datasetCleaner/main.go index 2cc1ffc..75eb5d5 100644 --- a/cmd/datasetCleaner/main.go +++ b/cmd/datasetCleaner/main.go @@ -116,7 +116,8 @@ func main() { return } - user, _ := datasetUtils.Authenticate(client, APIServer, token, userpass) + auth := &datasetUtils.RealAuthenticator{} + user, _ := datasetUtils.Authenticate(auth, client, APIServer, token, userpass) if user["username"] != "archiveManager" { log.Fatalf("You must be archiveManager to be allowed to delete datasets\n") diff --git a/cmd/datasetGetProposal/main.go b/cmd/datasetGetProposal/main.go index 1991ea8..98df8a9 100644 --- a/cmd/datasetGetProposal/main.go +++ b/cmd/datasetGetProposal/main.go @@ -85,7 +85,8 @@ func main() { return } - user, accessGroups := datasetUtils.Authenticate(client, APIServer, token, userpass) + auth := &datasetUtils.RealAuthenticator{} + user, accessGroups := datasetUtils.Authenticate(auth, client, APIServer, token, userpass) proposal := datasetUtils.GetProposal(client, APIServer, ownerGroup, user, accessGroups) // proposal is of type map[string]interface{} diff --git a/cmd/datasetIngestor/main.go b/cmd/datasetIngestor/main.go index 5d2c57d..0d7abd8 100644 --- a/cmd/datasetIngestor/main.go +++ b/cmd/datasetIngestor/main.go @@ -200,7 +200,8 @@ func main() { return } - user, accessGroups := datasetUtils.Authenticate(client, APIServer, token, userpass) + auth := &datasetUtils.RealAuthenticator{} + user, accessGroups := datasetUtils.Authenticate(auth, client, APIServer, token, userpass) /* TODO Add info about policy settings and that autoarchive will take place or not */ diff --git a/cmd/datasetPublishData/main.go b/cmd/datasetPublishData/main.go index b06b1b7..e8828f7 100644 --- a/cmd/datasetPublishData/main.go +++ b/cmd/datasetPublishData/main.go @@ -300,7 +300,8 @@ func createWebpage(urls []string, title string, doi string, datasetDetails []dat // set value in publishedData ============================== - user, _ := datasetUtils.Authenticate(client, APIServer, token, userpass) + auth := &datasetUtils.RealAuthenticator{} + user, _ := datasetUtils.Authenticate(auth, client, APIServer, token, userpass) type PublishedDataPart struct { DownloadLink string `json:"downloadLink"` diff --git a/cmd/datasetPublishDataRetrieve/main.go b/cmd/datasetPublishDataRetrieve/main.go index 49afdfa..34d4bf8 100644 --- a/cmd/datasetPublishDataRetrieve/main.go +++ b/cmd/datasetPublishDataRetrieve/main.go @@ -103,7 +103,8 @@ func main() { return } - user, _ := datasetUtils.Authenticate(client, APIServer, token, userpass) + auth := &datasetUtils.RealAuthenticator{} + user, _ := datasetUtils.Authenticate(auth, client, APIServer, token, userpass) datasetList, _, _ := datasetUtils.GetDatasetsOfPublication(client, APIServer, *publishedDataId) diff --git a/cmd/datasetRetriever/main.go b/cmd/datasetRetriever/main.go index 47791b8..a12b94f 100644 --- a/cmd/datasetRetriever/main.go +++ b/cmd/datasetRetriever/main.go @@ -120,7 +120,8 @@ func main() { return } - user, _ := datasetUtils.Authenticate(client, APIServer, token, userpass) + auth := &datasetUtils.RealAuthenticator{} + user, _ := datasetUtils.Authenticate(auth, client, APIServer, token, userpass) datasetList := datasetUtils.GetAvailableDatasets(user["username"], RSYNCServer, *datasetId) diff --git a/cmd/waitForJobFinished/main.go b/cmd/waitForJobFinished/main.go index b17c069..efb6cc8 100644 --- a/cmd/waitForJobFinished/main.go +++ b/cmd/waitForJobFinished/main.go @@ -84,7 +84,8 @@ func main() { return } - user, _ := datasetUtils.Authenticate(client, APIServer, token, userpass) + auth := &datasetUtils.RealAuthenticator{} + user, _ := datasetUtils.Authenticate(auth, client, APIServer, token, userpass) filter := `{"where":{"id":"` + *jobId + `"}}` diff --git a/datasetUtils/authenticate.go b/datasetUtils/authenticate.go index 513f36a..8eea3f8 100644 --- a/datasetUtils/authenticate.go +++ b/datasetUtils/authenticate.go @@ -8,7 +8,12 @@ import ( "syscall" ) -func Authenticate(httpClient *http.Client, APIServer string, token *string, userpass *string) (map[string]string, []string) { +// Authenticate handles user authentication by prompting the user for their credentials, +// validating these credentials against the authentication server, +// and returning an authentication token if the credentials are valid. +// This token can then be used for authenticated requests to the server. +// If the credentials are not valid, the function returns an error. +func Authenticate(auth Authenticator, httpClient *http.Client, APIServer string, token *string, userpass *string) (map[string]string, []string) { user := make(map[string]string) accessGroups := make([]string, 0) @@ -37,9 +42,9 @@ func Authenticate(httpClient *http.Client, APIServer string, token *string, user } username = strings.Split(*userpass, ":")[0] } - user, accessGroups = AuthenticateUser(httpClient, APIServer, username, password) + user, accessGroups = auth.AuthenticateUser(httpClient, APIServer, username, password) } else { - user, accessGroups = GetUserInfoFromToken(httpClient, APIServer, *token) + user, accessGroups = auth.GetUserInfoFromToken(httpClient, APIServer, *token) // extract password if defined in userpass value u := strings.Split(*userpass, ":") if len(u) == 2 { @@ -48,3 +53,19 @@ func Authenticate(httpClient *http.Client, APIServer string, token *string, user } return user, accessGroups } + +// An interface with the methods so that we can mock them in tests +type Authenticator interface { + AuthenticateUser(httpClient *http.Client, APIServer string, username string, password string) (map[string]string, []string) + GetUserInfoFromToken(httpClient *http.Client, APIServer string, token string) (map[string]string, []string) +} + +type RealAuthenticator struct{} + +func (r *RealAuthenticator) AuthenticateUser(httpClient *http.Client, APIServer string, username string, password string) (map[string]string, []string) { + return AuthenticateUser(httpClient, APIServer, username, password) +} + +func (r *RealAuthenticator) GetUserInfoFromToken(httpClient *http.Client, APIServer string, token string) (map[string]string, []string) { + return GetUserInfoFromToken(httpClient, APIServer, token) +} diff --git a/datasetUtils/authenticate_test.go b/datasetUtils/authenticate_test.go new file mode 100644 index 0000000..9f4f88d --- /dev/null +++ b/datasetUtils/authenticate_test.go @@ -0,0 +1,93 @@ +package datasetUtils + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +// Create a mock implementation of the interface +type MockAuthenticator struct{} + +func (m *MockAuthenticator) AuthenticateUser(httpClient *http.Client, APIServer string, username string, password string) (map[string]string, []string) { + if username == "" && password == "" { + return map[string]string{}, []string{} + } + return map[string]string{"username": "testuser", "password": "testpass"}, []string{"group1", "group2"} +} + +func (m *MockAuthenticator) GetUserInfoFromToken(httpClient *http.Client, APIServer string, token string) (map[string]string, []string) { + return map[string]string{"username": "tokenuser", "password": "tokenpass"}, []string{"group3", "group4"} +} + +func TestAuthenticate(t *testing.T) { + var auth Authenticator = &MockAuthenticator{} + // Mock HTTP server + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Write([]byte(`{"username": "testuser", "accessGroups": ["group1", "group2"]}`)) + })) + defer server.Close() + + // Test cases + tests := []struct { + name string + token string + userpass string + wantUser map[string]string + wantGroup []string + }{ + { + name: "Test with token", + token: "testtoken", + userpass: "", + wantUser: map[string]string{ + "username": "tokenuser", + "password": "tokenpass", + }, + wantGroup: []string{"group3", "group4"}, + }, + { + name: "Test with empty token and userpass", + token: "", + userpass: "", + wantUser: map[string]string{}, + wantGroup: []string{}, + }, + { + name: "Test with empty token and non-empty userpass", + token: "", + userpass: "testuser:testpass", + wantUser: map[string]string{ + "username": "testuser", + "password": "testpass", + }, + wantGroup: []string{"group1", "group2"}, + }, + { + name: "Test with non-empty token and empty userpass", + token: "testtoken", + userpass: "", + wantUser: map[string]string{ + "username": "tokenuser", + "password": "tokenpass", + }, + wantGroup: []string{"group3", "group4"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + httpClient := server.Client() + user, group := Authenticate(auth, httpClient, server.URL, &tt.token, &tt.userpass) + + if !reflect.DeepEqual(user, tt.wantUser) { + t.Errorf("got %v, want %v", user, tt.wantUser) + } + + if !reflect.DeepEqual(group, tt.wantGroup) { + t.Errorf("got %v, want %v", group, tt.wantGroup) + } + }) + } +}