Skip to content

Commit

Permalink
Add Encryption for GPQA (#3216)
Browse files Browse the repository at this point in the history
liamjxu authored Dec 20, 2024
1 parent b01f5f6 commit be8ac6b
Showing 8 changed files with 669 additions and 11 deletions.
16 changes: 11 additions & 5 deletions helm-frontend/src/components/Instances.tsx
Original file line number Diff line number Diff line change
@@ -23,9 +23,15 @@ interface Props {
runName: string;
suite: string;
metricFieldMap: MetricFieldMap;
userAgreed: boolean;
}

export default function Instances({ runName, suite, metricFieldMap }: Props) {
export default function Instances({
runName,
suite,
metricFieldMap,
userAgreed,
}: Props) {
const [searchParams, setSearchParams] = useSearchParams();
const [instances, setInstances] = useState<Instance[]>([]);
const [displayPredictionsMap, setDisplayPredictionsMap] = useState<
@@ -43,9 +49,9 @@ export default function Instances({ runName, suite, metricFieldMap }: Props) {

const [instancesResp, displayPredictions, displayRequests] =
await Promise.all([
getInstances(runName, signal, suite),
getDisplayPredictionsByName(runName, signal, suite),
getDisplayRequestsByName(runName, signal, suite),
getInstances(runName, signal, suite, userAgreed),
getDisplayPredictionsByName(runName, signal, suite, userAgreed),
getDisplayRequestsByName(runName, signal, suite, userAgreed),
]);
setInstances(instancesResp);

@@ -93,7 +99,7 @@ export default function Instances({ runName, suite, metricFieldMap }: Props) {
void fetchData();

return () => controller.abort();
}, [runName, suite]);
}, [runName, suite, userAgreed]);

const pagedInstances = instances.slice(
(currentInstancesPage - 1) * INSTANCES_PAGE_SIZE,
49 changes: 49 additions & 0 deletions helm-frontend/src/routes/Run.tsx
Original file line number Diff line number Diff line change
@@ -37,6 +37,9 @@ export default function Run() {
MetricFieldMap | undefined
>({});

const [agreeInput, setAgreeInput] = useState("");
const [userAgreed, setUserAgreed] = useState(false);

useEffect(() => {
const controller = new AbortController();
async function fetchData() {
@@ -93,6 +96,16 @@ export default function Run() {
return <Loading />;
}

// Handler for agreement
const handleAgreement = () => {
if (agreeInput.trim() === "Yes, I agree") {
setUserAgreed(true);
} else {
setUserAgreed(false);
alert("Please type 'Yes, I agree' exactly.");
}
};

return (
<>
<div className="flex justify-between gap-8 mb-12">
@@ -178,11 +191,47 @@ export default function Run() {
</Tab>
</Tabs>
</div>

{activeTab === 0 && runName.includes("gpqa") && !userAgreed && (
<div className="mb-8">
<hr className="my-4" />
<p className="mb-4">
The GPQA dataset instances are encrypted by default to comply with
the following request:
</p>
<blockquote className="italic border-l-4 border-gray-300 pl-4 text-gray-700 mb-4">
“We ask that you do not reveal examples from this dataset in plain
text or images online, to minimize the risk of these instances being
included in foundation model training corpora.”
</blockquote>
<p className="mb-4">
If you agree to this condition, please type{" "}
<strong>"Yes, I agree"</strong> in the box below and then click{" "}
<strong>Decrypt</strong>.
</p>
<div className="flex gap-2 mt-2">
<input
type="text"
value={agreeInput}
onChange={(e) => setAgreeInput(e.target.value)}
className="input input-bordered"
placeholder='Type "Yes, I agree"'
/>
<button onClick={handleAgreement} className="btn btn-primary">
Decrypt
</button>
</div>
<hr className="my-4" />
</div>
)}

{activeTab === 0 ? (
<Instances
key={userAgreed ? "instances-agreed" : "instances-not-agreed"}
runName={runName}
suite={runSuite}
metricFieldMap={metricFieldMap}
userAgreed={userAgreed} // Pass the boolean to Instances
/>
) : (
<RunMetrics
74 changes: 72 additions & 2 deletions helm-frontend/src/services/getDisplayPredictionsByName.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,93 @@
import type DisplayPrediction from "@/types/DisplayPrediction";
import { EncryptionDataMap } from "@/types/EncryptionDataMap";
import getBenchmarkEndpoint from "@/utils/getBenchmarkEndpoint";
import getBenchmarkSuite from "@/utils/getBenchmarkSuite";

async function decryptField(
ciphertext: string,
key: string,
iv: string,
tag: string,
): Promise<string> {
const decodeBase64 = (str: string) =>
Uint8Array.from(atob(str), (c) => c.charCodeAt(0));

const cryptoKey = await window.crypto.subtle.importKey(
"raw",
decodeBase64(key),
"AES-GCM",
true,
["decrypt"],
);

const combinedCiphertext = new Uint8Array([
...decodeBase64(ciphertext),
...decodeBase64(tag),
]);

const ivArray = decodeBase64(iv);

const decrypted = await window.crypto.subtle.decrypt(
{ name: "AES-GCM", iv: ivArray },
cryptoKey,
combinedCiphertext,
);

return new TextDecoder().decode(decrypted);
}

export default async function getDisplayPredictionsByName(
runName: string,
signal: AbortSignal,
suite?: string,
userAgreed?: boolean,
): Promise<DisplayPrediction[]> {
try {
const displayPrediction = await fetch(
const response = await fetch(
getBenchmarkEndpoint(
`/runs/${
suite || getBenchmarkSuite()
}/${runName}/display_predictions.json`,
),
{ signal },
);
const displayPredictions = (await response.json()) as DisplayPrediction[];

if (runName.includes("gpqa") && userAgreed) {
const encryptionResponse = await fetch(
getBenchmarkEndpoint(
`/runs/${
suite || getBenchmarkSuite()
}/${runName}/encryption_data.json`,
),
{ signal },
);
const encryptionData =
(await encryptionResponse.json()) as EncryptionDataMap;

for (const prediction of displayPredictions) {
const encryptedText = prediction.predicted_text;
const encryptionDetails = encryptionData[encryptedText];

if (encryptionDetails) {
try {
prediction.predicted_text = await decryptField(
encryptionDetails.ciphertext,
encryptionDetails.key,
encryptionDetails.iv,
encryptionDetails.tag,
);
} catch (error) {
console.error(
`Failed to decrypt predicted_text for instance_id: ${prediction.instance_id}`,
error,
);
}
}
}
}

return (await displayPrediction.json()) as DisplayPrediction[];
return displayPredictions;
} catch (error) {
if (error instanceof Error && error.name === "AbortError") {
console.log(error);
75 changes: 73 additions & 2 deletions helm-frontend/src/services/getDisplayRequestsByName.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,94 @@
import type DisplayRequest from "@/types/DisplayRequest";
import { EncryptionDataMap } from "@/types/EncryptionDataMap";
import getBenchmarkEndpoint from "@/utils/getBenchmarkEndpoint";
import getBenchmarkSuite from "@/utils/getBenchmarkSuite";

// Helper function for decryption
async function decryptField(
ciphertext: string,
key: string,
iv: string,
tag: string,
): Promise<string> {
const decodeBase64 = (str: string) =>
Uint8Array.from(atob(str), (c) => c.charCodeAt(0));

const cryptoKey = await window.crypto.subtle.importKey(
"raw",
decodeBase64(key),
"AES-GCM",
true,
["decrypt"],
);

const combinedCiphertext = new Uint8Array([
...decodeBase64(ciphertext),
...decodeBase64(tag),
]);

const ivArray = decodeBase64(iv);

const decrypted = await window.crypto.subtle.decrypt(
{ name: "AES-GCM", iv: ivArray },
cryptoKey,
combinedCiphertext,
);

return new TextDecoder().decode(decrypted);
}

export default async function getDisplayRequestsByName(
runName: string,
signal: AbortSignal,
suite?: string,
userAgreed?: boolean,
): Promise<DisplayRequest[]> {
try {
const displayRequest = await fetch(
const response = await fetch(
getBenchmarkEndpoint(
`/runs/${
suite || getBenchmarkSuite()
}/${runName}/display_requests.json`,
),
{ signal },
);
const displayRequests = (await response.json()) as DisplayRequest[];

if (runName.startsWith("gpqa") && userAgreed) {
const encryptionResponse = await fetch(
getBenchmarkEndpoint(
`/runs/${
suite || getBenchmarkSuite()
}/${runName}/encryption_data.json`,
),
{ signal },
);
const encryptionData =
(await encryptionResponse.json()) as EncryptionDataMap;

for (const request of displayRequests) {
const encryptedPrompt = request.request.prompt;
const encryptionDetails = encryptionData[encryptedPrompt];

if (encryptionDetails) {
try {
request.request.prompt = await decryptField(
encryptionDetails.ciphertext,
encryptionDetails.key,
encryptionDetails.iv,
encryptionDetails.tag,
);
} catch (error) {
console.error(
`Failed to decrypt prompt for instance_id: ${request.instance_id}`,
error,
);
}
}
}
}

return (await displayRequest.json()) as DisplayRequest[];
return displayRequests;
} catch (error) {
if (error instanceof Error && error.name !== "AbortError") {
console.log(error);
80 changes: 78 additions & 2 deletions helm-frontend/src/services/getInstances.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,97 @@
import Instance from "@/types/Instance";
import { EncryptionDataMap } from "@/types/EncryptionDataMap";
import getBenchmarkEndpoint from "@/utils/getBenchmarkEndpoint";
import getBenchmarkSuite from "@/utils/getBenchmarkSuite";

// Helper function for decryption
async function decryptField(
ciphertext: string,
key: string,
iv: string,
tag: string,
): Promise<string> {
// Convert Base64 strings to Uint8Array
const decodeBase64 = (str: string) =>
Uint8Array.from(atob(str), (c) => c.charCodeAt(0));

const cryptoKey = await window.crypto.subtle.importKey(
"raw",
decodeBase64(key),
"AES-GCM",
true,
["decrypt"],
);

const combinedCiphertext = new Uint8Array([
...decodeBase64(ciphertext),
...decodeBase64(tag),
]);

const ivArray = decodeBase64(iv);

const decrypted = await window.crypto.subtle.decrypt(
{ name: "AES-GCM", iv: ivArray },
cryptoKey,
combinedCiphertext,
);

return new TextDecoder().decode(decrypted);
}

export default async function getInstancesByRunName(
runName: string,
signal: AbortSignal,
suite?: string,
userAgreed?: boolean,
): Promise<Instance[]> {
try {
const instances = await fetch(
const response = await fetch(
getBenchmarkEndpoint(
`/runs/${suite || getBenchmarkSuite()}/${runName}/instances.json`,
),
{ signal },
);
const instances = (await response.json()) as Instance[];

if (runName.includes("gpqa") && userAgreed) {
const encryptionResponse = await fetch(
getBenchmarkEndpoint(
`/runs/${
suite || getBenchmarkSuite()
}/${runName}/encryption_data.json`,
),
{ signal },
);
const encryptionData =
(await encryptionResponse.json()) as EncryptionDataMap;

for (const instance of instances) {
const inputEncryption = encryptionData[instance.input.text];
if (inputEncryption) {
instance.input.text = "encrypted";
instance.input.text = await decryptField(
inputEncryption.ciphertext,
inputEncryption.key,
inputEncryption.iv,
inputEncryption.tag,
);
}

for (const reference of instance.references) {
const referenceEncryption = encryptionData[reference.output.text];
if (referenceEncryption) {
reference.output.text = await decryptField(
referenceEncryption.ciphertext,
referenceEncryption.key,
referenceEncryption.iv,
referenceEncryption.tag,
);
}
}
}
}

return (await instances.json()) as Instance[];
return instances;
} catch (error) {
if (error instanceof Error && error.name !== "AbortError") {
console.log(error);
8 changes: 8 additions & 0 deletions helm-frontend/src/types/EncryptionDataMap.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export default interface EncryptionDetails {
ciphertext: string;
key: string;
iv: string;
tag: string;
}

export type EncryptionDataMap = Record<string, EncryptionDetails>;
191 changes: 191 additions & 0 deletions scripts/decrypt_scenario_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import argparse
import dataclasses
import json
import os
import base64
from typing import Dict, Optional
from tqdm import tqdm
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.scenarios.scenario import Instance, Reference
from helm.common.codec import from_json, to_json
from helm.common.hierarchical_logger import hlog
from helm.common.request import Request, RequestResult

_SCENARIO_STATE_FILE_NAME = "scenario_state.json"
_DECRYPTED_SCENARIO_STATE_FILE_NAME = "decrypted_scenario_state.json"
_DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME = "encryption_data.json"


class HELMDecryptor:
def __init__(self, encryption_data_mapping: Dict[str, Dict[str, str]]):
"""
encryption_data_mapping is a dict like:
{
"[encrypted_text_0]": {
"ciphertext": "...",
"key": "...",
"iv": "...",
"tag": "..."
},
...
}
"""
self.encryption_data_mapping = encryption_data_mapping

def decrypt_text(self, text: str) -> str:
if text.startswith("[encrypted_text_") and text.endswith("]"):
data = self.encryption_data_mapping.get(text)
if data is None:
# If not found in encryption data, return as-is or raise error
raise ValueError(f"No decryption data found for {text}")

ciphertext = base64.b64decode(data["ciphertext"])
key = base64.b64decode(data["key"])
iv = base64.b64decode(data["iv"])
tag = base64.b64decode(data["tag"])

cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend())
decryptor = cipher.decryptor()
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
return plaintext.decode("utf-8")
else:
# Not an encrypted placeholder, return as is.
return text


def read_scenario_state(scenario_state_path: str) -> ScenarioState:
if not os.path.exists(scenario_state_path):
raise ValueError(f"Could not load ScenarioState from {scenario_state_path}")
with open(scenario_state_path) as f:
return from_json(f.read(), ScenarioState)


def write_scenario_state(scenario_state_path: str, scenario_state: ScenarioState) -> None:
with open(scenario_state_path, "w") as f:
f.write(to_json(scenario_state))


def read_encryption_data(encryption_data_path: str) -> Dict[str, Dict[str, str]]:
if not os.path.exists(encryption_data_path):
raise ValueError(f"Could not load encryption data from {encryption_data_path}")
with open(encryption_data_path) as f:
return json.load(f)


def decrypt_reference(reference: Reference, decryptor: HELMDecryptor) -> Reference:
decrypted_output = dataclasses.replace(reference.output, text=decryptor.decrypt_text(reference.output.text))
return dataclasses.replace(reference, output=decrypted_output)


def decrypt_instance(instance: Instance, decryptor: HELMDecryptor) -> Instance:
decrypted_input = dataclasses.replace(instance.input, text=decryptor.decrypt_text(instance.input.text))
decrypted_references = [decrypt_reference(reference, decryptor) for reference in instance.references]
return dataclasses.replace(instance, input=decrypted_input, references=decrypted_references)


def decrypt_request(request: Request, decryptor: HELMDecryptor) -> Request:
# The encryption script sets request.messages and multimodal_prompt to None, so we don't need to decrypt them
return dataclasses.replace(request, prompt=decryptor.decrypt_text(request.prompt))


def decrypt_output_mapping(
output_mapping: Optional[Dict[str, str]], decryptor: HELMDecryptor
) -> Optional[Dict[str, str]]:
if output_mapping is None:
return None
return {key: decryptor.decrypt_text(val) for key, val in output_mapping.items()}


def decrypt_result(result: Optional[RequestResult], decryptor: HELMDecryptor) -> Optional[RequestResult]:
if result is None:
return None

decrypted_completions = [
dataclasses.replace(completion, text=decryptor.decrypt_text(completion.text))
for completion in result.completions
]
return dataclasses.replace(result, completions=decrypted_completions)


def decrypt_request_state(request_state: RequestState, decryptor: HELMDecryptor) -> RequestState:
return dataclasses.replace(
request_state,
instance=decrypt_instance(request_state.instance, decryptor),
request=decrypt_request(request_state.request, decryptor),
output_mapping=decrypt_output_mapping(request_state.output_mapping, decryptor),
result=decrypt_result(request_state.result, decryptor),
)


def decrypt_scenario_state(scenario_state: ScenarioState, decryptor: HELMDecryptor) -> ScenarioState:
decrypted_request_states = [decrypt_request_state(rs, decryptor) for rs in scenario_state.request_states]
return dataclasses.replace(scenario_state, request_states=decrypted_request_states)


def modify_scenario_state_for_run(run_path: str) -> None:
scenario_state_path = os.path.join(run_path, _SCENARIO_STATE_FILE_NAME)
encryption_data_path = os.path.join(run_path, _DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME)

scenario_state = read_scenario_state(scenario_state_path)
encryption_data_mapping = read_encryption_data(encryption_data_path)
decryptor = HELMDecryptor(encryption_data_mapping)

decrypted_scenario_state = decrypt_scenario_state(scenario_state, decryptor)
decrypted_scenario_state_path = os.path.join(run_path, _DECRYPTED_SCENARIO_STATE_FILE_NAME)
write_scenario_state(decrypted_scenario_state_path, decrypted_scenario_state)


def modify_scenario_states_for_suite(run_suite_path: str, scenario: str) -> None:
scenario_prefix = scenario if scenario != "all" else ""
run_dir_names = sorted(
[
p
for p in os.listdir(run_suite_path)
if p != "eval_cache"
and p != "groups"
and os.path.isdir(os.path.join(run_suite_path, p))
and p.startswith(scenario_prefix)
]
)
for run_dir_name in tqdm(run_dir_names, disable=None):
scenario_state_path: str = os.path.join(run_suite_path, run_dir_name, _SCENARIO_STATE_FILE_NAME)
encryption_data_path = os.path.join(run_suite_path, run_dir_name, _DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME)
if not os.path.exists(scenario_state_path):
hlog(f"WARNING: {run_dir_name} doesn't have {_SCENARIO_STATE_FILE_NAME}, skipping")
continue
if not os.path.exists(encryption_data_path):
hlog(f"WARNING: {run_dir_name} doesn't have {_DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME}, skipping")
continue
run_path: str = os.path.join(run_suite_path, run_dir_name)
modify_scenario_state_for_run(run_path)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-o", "--output-path", type=str, help="Where the benchmarking output lives", default="benchmark_output"
)
parser.add_argument(
"--suite",
type=str,
help="Name of the suite this decryption should go under.",
)
parser.add_argument(
"--scenario",
type=str,
default="all",
help="Name of the scenario this decryption should go under. Default is all.",
)
args = parser.parse_args()
output_path = args.output_path
suite = args.suite
run_suite_path = os.path.join(output_path, "runs", suite)
modify_scenario_states_for_suite(run_suite_path, scenario=args.scenario)


if __name__ == "__main__":
main()
187 changes: 187 additions & 0 deletions scripts/encrypt_scenario_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""Encrypts prompts from scenario state.
This script modifies all scenario_state.json files in place within a suite to
encrypting all prompts, instance input text, and instance reference output text
from the `ScenarioState`s.
This is used when the scenario contains prompts that should not be displayed,
in order to reduce the chance of data leakage or to comply with data privacy
requirements.
After running this, you must re-run helm-summarize on the suite in order to
update other JSON files used by the web frontend."""

import argparse
import dataclasses
import os
import base64
from typing import Dict, Optional
from tqdm import tqdm
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.scenarios.scenario import Instance, Reference
from helm.common.codec import from_json, to_json
from helm.common.hierarchical_logger import hlog
from helm.common.request import Request, RequestResult
from helm.common.general import write


_SCENARIO_STATE_FILE_NAME = "scenario_state.json"
_DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME = "encryption_data.json"


class HELMEncryptor:
def __init__(self, key, iv):
self.key = key
self.iv = iv
self.encryption_data_mapping = {}
self.idx = 0

def encrypt_text(self, text: str) -> str:
cipher = Cipher(algorithms.AES(self.key), modes.GCM(self.iv), backend=default_backend())
encryptor = cipher.encryptor()
ciphertext = encryptor.update(text.encode()) + encryptor.finalize()
ret_text = f"[encrypted_text_{self.idx}]"

res = {
"ciphertext": base64.b64encode(ciphertext).decode(),
"key": base64.b64encode(self.key).decode(),
"iv": base64.b64encode(self.iv).decode(),
"tag": base64.b64encode(encryptor.tag).decode(),
}
assert ret_text not in self.encryption_data_mapping
self.encryption_data_mapping[ret_text] = res
self.idx += 1
return ret_text


def read_scenario_state(scenario_state_path: str) -> ScenarioState:
if not os.path.exists(scenario_state_path):
raise ValueError(f"Could not load ScenarioState from {scenario_state_path}")
with open(scenario_state_path) as f:
return from_json(f.read(), ScenarioState)


def write_scenario_state(scenario_state_path: str, scenario_state: ScenarioState) -> None:
with open(scenario_state_path, "w") as f:
f.write(to_json(scenario_state))


def encrypt_reference(reference: Reference) -> Reference:
global encryptor
encrypted_output = dataclasses.replace(reference.output, text=encryptor.encrypt_text(reference.output.text))
return dataclasses.replace(reference, output=encrypted_output)


def encrypt_instance(instance: Instance) -> Instance:
global encryptor
encrypted_input = dataclasses.replace(instance.input, text=encryptor.encrypt_text(instance.input.text))
encrypted_references = [encrypt_reference(reference) for reference in instance.references]
return dataclasses.replace(instance, input=encrypted_input, references=encrypted_references)


def encrypt_request(request: Request) -> Request:
global encryptor
return dataclasses.replace(
request, prompt=encryptor.encrypt_text(request.prompt), messages=None, multimodal_prompt=None
)


def encrypt_output_mapping(output_mapping: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]:
if output_mapping is None:
return None
return {key: encryptor.encrypt_text(val) for key, val in output_mapping.items()}


def encrypt_result(result: Optional[RequestResult]) -> Optional[RequestResult]:
if result is None:
return None

encrypted_results = [
dataclasses.replace(completion, text=encryptor.encrypt_text(completion.text))
for completion in result.completions
]
return dataclasses.replace(result, completions=encrypted_results)


def encrypt_request_state(request_state: RequestState) -> RequestState:
return dataclasses.replace(
request_state,
instance=encrypt_instance(request_state.instance),
request=encrypt_request(request_state.request),
output_mapping=encrypt_output_mapping(request_state.output_mapping),
result=encrypt_result(request_state.result),
)


def encrypt_scenario_state(scenario_state: ScenarioState) -> ScenarioState:
encrypted_request_states = [encrypt_request_state(request_state) for request_state in scenario_state.request_states]
return dataclasses.replace(scenario_state, request_states=encrypted_request_states)


def modify_scenario_state_for_run(run_path: str) -> None:
scenario_state_path = os.path.join(run_path, _SCENARIO_STATE_FILE_NAME)
scenario_state = read_scenario_state(scenario_state_path)
encrypted_scenario_state = encrypt_scenario_state(scenario_state)
write_scenario_state(scenario_state_path, encrypted_scenario_state)


def modify_scenario_states_for_suite(run_suite_path: str, scenario: str) -> None:
"""Load the runs in the run suite path."""
# run_suite_path can contain subdirectories that are not runs (e.g. eval_cache, groups)
# so filter them out.
scenario_prefix = scenario if scenario != "all" else ""
run_dir_names = sorted(
[
p
for p in os.listdir(run_suite_path)
if p != "eval_cache"
and p != "groups"
and os.path.isdir(os.path.join(run_suite_path, p))
and p.startswith(scenario_prefix)
]
)
for run_dir_name in tqdm(run_dir_names, disable=None):
scenario_state_path: str = os.path.join(run_suite_path, run_dir_name, _SCENARIO_STATE_FILE_NAME)
if not os.path.exists(scenario_state_path):
hlog(f"WARNING: {run_dir_name} doesn't have {_SCENARIO_STATE_FILE_NAME}, skipping")
continue
run_path: str = os.path.join(run_suite_path, run_dir_name)
modify_scenario_state_for_run(run_path)

# Write the encryption data to a file
encryption_data_path = os.path.join(run_path, _DISPLAY_ENCRYPTION_DATA_JSON_FILE_NAME)
write(encryption_data_path, to_json(encryptor.encryption_data_mapping))


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-o", "--output-path", type=str, help="Where the benchmarking output lives", default="benchmark_output"
)
parser.add_argument(
"--suite",
type=str,
help="Name of the suite this encryption should go under.",
)
parser.add_argument(
"--scenario",
type=str,
default="all",
help="Name of the scenario this encryption should go under. Default is all.",
)
args = parser.parse_args()
output_path = args.output_path
suite = args.suite
run_suite_path = os.path.join(output_path, "runs", suite)
modify_scenario_states_for_suite(run_suite_path, scenario=args.scenario)


if __name__ == "__main__":
key = os.urandom(32) # 256-bit key
iv = os.urandom(12) # 96-bit IV (suitable for AES-GCM)
encryptor = HELMEncryptor(key, iv)
main()

0 comments on commit be8ac6b

Please sign in to comment.