diff --git a/instruqt/team.go b/instruqt/team.go index 472661e..5195d14 100644 --- a/instruqt/team.go +++ b/instruqt/team.go @@ -15,7 +15,14 @@ package instruqt import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/pem" "fmt" + "net/url" "github.com/shurcooL/graphql" ) @@ -45,3 +52,60 @@ func (c *Client) GetTPGPublicKey() (string, error) { return string(q.Team.TPGPublicKey), nil } + +// EncryptPII encrypts PII using the public key fetched from the GetTPGPublicKey function. +// It takes a string representing the PII data, encodes it, and then encrypts it using RSA. +func (c *Client) EncryptPII(encodedPII string) (string, error) { + // Fetch the public key using the GetTPGPublicKey function + publicKeyPEM, err := c.GetTPGPublicKey() + if err != nil { + return "", fmt.Errorf("failed to get public key: %v", err) + } + + // Decode the PEM public key + block, _ := pem.Decode([]byte(publicKeyPEM)) + if block == nil || block.Type != "RSA PUBLIC KEY" { + return "", fmt.Errorf("failed to decode PEM block containing public key") + } + + // Parse the public key + publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return "", fmt.Errorf("failed to parse DER encoded public key: %v", err) + } + + // Assert the public key is of type *rsa.PublicKey + rsaPublicKey, ok := publicKey.(*rsa.PublicKey) + if !ok { + return "", fmt.Errorf("not an RSA public key") + } + + // Encrypt the PII + hash := sha256.New() + encryptedPII, err := rsa.EncryptOAEP(hash, rand.Reader, rsaPublicKey, []byte(encodedPII), nil) + if err != nil { + return "", fmt.Errorf("failed to encrypt PII: %v", err) + } + + // Encode the encrypted data to base64 + encryptedPIIBase64 := base64.StdEncoding.EncodeToString(encryptedPII) + return encryptedPIIBase64, nil +} + +// EncryptUserPII creates PII data (first name, last name, and email) and encrypts it using the public key. +func (c *Client) EncryptUserPII(firstName, lastName, email string) (string, error) { + // Prepare the PII data + piiData := url.Values{ + "fn": {firstName}, + "ln": {lastName}, + "e": {email}, + } + + // Encrypt the PII data + encryptedPII, err := c.EncryptPII(piiData.Encode()) + if err != nil { + return "", err + } + + return encryptedPII, nil +} \ No newline at end of file diff --git a/instruqt/team_test.go b/instruqt/team_test.go index e89e045..d8cfd1c 100644 --- a/instruqt/team_test.go +++ b/instruqt/team_test.go @@ -15,6 +15,11 @@ package instruqt import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "net/url" "testing" "github.com/shurcooL/graphql" @@ -40,3 +45,71 @@ func TestGetTPGPublicKey(t *testing.T) { assert.Equal(t, expectedPublicKey, publicKey) mockClient.AssertExpectations(t) } + +func TestEncryptPII(t *testing.T) { + // Generate a temporary RSA key pair for testing + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + // Extract the public key and encode it in PEM format + publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + assert.NoError(t, err) + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: publicKeyBytes, + }) + + // Create the mock client + mockClient := new(MockGraphQLClient) + client := &Client{ + GraphQLClient: mockClient, + } + mockClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + query := args.Get(1).(*teamQuery) + query.Team.TPGPublicKey = graphql.String(publicKeyPEM) + }).Return(nil) + + // Create the PII data to be encrypted + data := url.Values{} + data.Set("email", "test@example.com") + data.Set("first_name", "John") + data.Set("last_name", "Doe") + + // Call the EncryptPII function + encryptedPII, err := client.EncryptPII(data.Encode()) + assert.NoError(t, err) + + // Ensure the encrypted PII is a non-empty string + assert.NotEmpty(t, encryptedPII) +} + +func TestEncryptUserPII(t *testing.T) { + // Generate a temporary RSA key pair for testing + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + // Extract the public key and encode it in PEM format + publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + assert.NoError(t, err) + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: publicKeyBytes, + }) + + // Create the mock client + mockClient := new(MockGraphQLClient) + client := &Client{ + GraphQLClient: mockClient, + } + mockClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + query := args.Get(1).(*teamQuery) + query.Team.TPGPublicKey = graphql.String(publicKeyPEM) + }).Return(nil) + + // Call the EncryptUserPII function + encryptedPII, err := client.EncryptUserPII("John", "Doe", "test@example.com") + assert.NoError(t, err) + + // Ensure the encrypted PII is a non-empty string + assert.NotEmpty(t, encryptedPII) +}