Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test functions in scp.go #56

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions datasetIngestor/scp.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@ func keyString(k ssh.PublicKey) string {
return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal()) // e.g. "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTY...."
}

/* trustedHostKeyCallback returns a function that serves as a callback for SSH host key verification.
If a trustedKey is provided, the callback will verify if the key from the server matches the trustedKey.
If they don't match, it returns an error.
If no trustedKey is provided, the callback will log a warning that SSH-key verification is not in effect,
but it will not stop the connection.
Parameters:
trustedKey: A string representation of the trusted SSH public key.
Returns:
An ssh.HostKeyCallback function for SSH host key verification. */
func trustedHostKeyCallback(trustedKey string) ssh.HostKeyCallback {
minottic marked this conversation as resolved.
Show resolved Hide resolved
trustedKey = strings.TrimSpace(trustedKey)
if trustedKey == "" {
return func(_ string, _ net.Addr, k ssh.PublicKey) error {
log.Printf("WARNING: SSH-key verification is *NOT* in effect: to fix, add this trustedKey: %q", keyString(k))
Expand All @@ -42,6 +52,7 @@ func trustedHostKeyCallback(trustedKey string) ssh.HostKeyCallback {

return func(_ string, _ net.Addr, k ssh.PublicKey) error {
ks := keyString(k)
ks = strings.TrimSpace(ks)
if trustedKey != ks {
return fmt.Errorf("SSH-key verification: expected %q but got %q", trustedKey, ks)
}
Expand Down Expand Up @@ -137,6 +148,13 @@ func (c *Client) sendRegularFile(w io.Writer, path string, fi os.FileInfo) error
}

// Walk and Send directory
/* walkAndSend recursively walks through the directory specified by 'src',
and sends each file it encounters to the writer 'w'.
If 'src' is a regular file, it sends the file directly.
If 'src' is a directory, it walks through the directory and sends each file it encounters.
It also sends directory change commands (push and pop) to the writer.
If 'c.PreseveTimes' is true, it sends the modification time of each file and directory to the writer.
It returns an error if any operation fails. */
func (c *Client) walkAndSend(w io.Writer, src string) error {
cleanedPath := filepath.Clean(src)

Expand Down
194 changes: 194 additions & 0 deletions datasetIngestor/scp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package datasetIngestor

import (
"testing"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"golang.org/x/crypto/ssh"
"bytes"
"os"
"fmt"
)

// Checks if the function returns an error when the provided key does not match the trusted key.
func TestTrustedHostKeyCallback(t *testing.T) {
// Generate a test key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}

publicKeyRsa, err := ssh.NewPublicKey(&privateKey.PublicKey)
if err != nil {
t.Fatalf("Failed to generate public key: %v", err)
}

publicKeyBytes := ssh.MarshalAuthorizedKey(publicKeyRsa)
trustedKey := string(publicKeyBytes)

// Create the callback
callback := trustedHostKeyCallback(trustedKey)

// Test the callback with the correct key
err = callback("", nil, publicKeyRsa)
if err != nil {
t.Errorf("Expected no error for correct key, got: %v", err)
}

// Generate a different key for testing mismatch
privateKey2, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}

publicKeyRsa2, err := ssh.NewPublicKey(&privateKey2.PublicKey)
if err != nil {
t.Fatalf("Failed to generate public key: %v", err)
}

// Test the callback with a different key
err = callback("", nil, publicKeyRsa2)
if err == nil {
t.Errorf("Expected error for incorrect key, got nil")
}
}

func generateTestKeyPair(bits int) (string, string, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return "", "", err
}

privateKeyDer := x509.MarshalPKCS1PrivateKey(privateKey)
privateKeyBlock := pem.Block{
Type: "RSA PRIVATE KEY",
Headers: nil,
Bytes: privateKeyDer,
}
privateKeyPem := string(pem.EncodeToMemory(&privateKeyBlock))

publicKeyRsa, err := ssh.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return "", "", err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKeyRsa)
publicKeyPem := string(publicKeyBytes)

return privateKeyPem, publicKeyPem, nil
}

func TestGetSendCommand(t *testing.T) {
client := &Client{
PreseveTimes: true,
Quiet: true,
}

dst := "/path/to/destination"
expected := "scp -rtpq /path/to/destination"

result := client.getSendCommand(dst)

if result != expected {
t.Errorf("getSendCommand() = %s; want %s", result, expected)
}
}

// Checks if the function correctly sends a regular file and handles errors properly
func TestSendRegularFile(t *testing.T) {
client := &Client{
PreseveTimes: true,
Quiet: true,
}

// Create a temporary file for testing
tmpfile, err := os.CreateTemp("", "example")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name()) // clean up

// Write some data to the file
text := []byte("This is a test file.")
if _, err := tmpfile.Write(text); err != nil {
t.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
}

// Get file info
fi, err := os.Stat(tmpfile.Name())
if err != nil {
t.Fatal(err)
}

// Create a buffer to use as the writer
var buf bytes.Buffer

// Call the function
err = client.sendRegularFile(&buf, tmpfile.Name(), fi)
if err != nil {
t.Errorf("sendRegularFile() error = %v", err)
}

// Check if the file content was written to the buffer
if !bytes.Contains(buf.Bytes(), text) {
t.Errorf("sendRegularFile() did not write file content to writer")
}

// Check if the file permissions were written to the buffer
perm := fmt.Sprintf("C%04o", fi.Mode().Perm())
if !bytes.Contains(buf.Bytes(), []byte(perm)) {
t.Errorf("sendRegularFile() did not write file permissions to writer")
}
}

// Checks if the function returns an error, if the file content was written to the buffer, and if the directory change commands were written to the buffer
func TestWalkAndSend(t *testing.T) {
client := &Client{
PreseveTimes: true,
Quiet: true,
}

// Create a temporary directory for testing
tmpDir, err := os.MkdirTemp("", "example")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir) // clean up

// Create a file in the temporary directory
tmpfile, err := os.CreateTemp(tmpDir, "file")
if err != nil {
t.Fatal(err)
}
text := []byte("This is a test file.")
if _, err := tmpfile.Write(text); err != nil {
t.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
}

// Create a buffer to use as the writer
var buf bytes.Buffer

// Call the function
err = client.walkAndSend(&buf, tmpDir)
if err != nil {
t.Errorf("walkAndSend() error = %v", err)
}

// Check if the file content was written to the buffer
if !bytes.Contains(buf.Bytes(), text) {
t.Errorf("walkAndSend() did not write file content to writer")
}

// Check if the directory change commands were written to the buffer
if !bytes.Contains(buf.Bytes(), []byte("D")) || !bytes.Contains(buf.Bytes(), []byte("E")) {
t.Errorf("walkAndSend() did not write directory change commands to writer")
}
}

Loading