From 003e8aed3be28c56ea6b9b7d784cc142d41e7c19 Mon Sep 17 00:00:00 2001 From: Miguel Alonso Jr <76960110+miguelalonsojr@users.noreply.github.com> Date: Wed, 13 Oct 2021 09:28:10 -0400 Subject: [PATCH] Develop add training area replicator (#5568) * Added training area replicator to com.unity.ml-agents package. * Added num_areas to Unity RL Initialization proto. Added cli and config file support for num_areas. * Changed training area replicator to size grid automatically from number of areas. * Added tests for the training area replicator. * Added setup for tests for the training area replicator. * Added comments and updated create tutorial for training area replicator. * Updated CHANGELOG. * Fixed some failing tests. * Update com.unity.ml-agents/CHANGELOG.md Co-authored-by: Henry Peteet * Update docs/Learning-Environment-Create-New.md Co-authored-by: Henry Peteet * Update com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs Co-authored-by: Henry Peteet * Addressed CR comments. Co-authored-by: Miguel Alonso Jr Co-authored-by: Henry Peteet --- com.unity.ml-agents/CHANGELOG.md | 1 + com.unity.ml-agents/Runtime/Academy.cs | 6 ++ com.unity.ml-agents/Runtime/Areas.meta | 8 ++ .../Runtime/Areas/TrainingAreaReplicator.cs | 88 +++++++++++++++++++ .../Areas/TrainingAreaReplicator.cs.meta | 11 +++ .../Runtime/Communicator/GrpcExtensions.cs | 1 + .../Runtime/Communicator/ICommunicator.cs | 5 ++ .../UnityRlInitializationInput.cs | 39 +++++++- .../Runtime/Unity.ML-Agents.asmdef | 3 +- com.unity.ml-agents/Tests/Editor/Areas.meta | 3 + .../Areas/TrainingAreaReplicatorTests.cs | 61 +++++++++++++ .../Areas/TrainingAreaReplicatorTests.cs.meta | 11 +++ .../Unity.ML-Agents.Editor.Tests.asmdef | 1 + docs/Learning-Environment-Create-New.md | 8 ++ docs/Python-API-Documentation.md | 2 +- .../unity_rl_initialization_input_pb2.py | 11 ++- .../unity_rl_initialization_input_pb2.pyi | 6 +- ml-agents-envs/mlagents_envs/environment.py | 2 + ml-agents/mlagents/trainers/cli_utils.py | 9 ++ 
ml-agents/mlagents/trainers/learn.py | 9 +- ml-agents/mlagents/trainers/settings.py | 6 ++ .../mlagents/trainers/tests/test_learn.py | 10 ++- .../mlagents/trainers/tests/test_settings.py | 4 + .../unity_rl_initialization_input.proto | 3 + 24 files changed, 295 insertions(+), 13 deletions(-) create mode 100644 com.unity.ml-agents/Runtime/Areas.meta create mode 100644 com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs create mode 100644 com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs.meta create mode 100644 com.unity.ml-agents/Tests/Editor/Areas.meta create mode 100644 com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs create mode 100644 com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs.meta diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index ebdfcc5484..a63cee98c2 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to ### Major Changes #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) +- Added a new feature to replicate training areas dynamically during runtime. (#5568) #### ml-agents / ml-agents-envs / gym-unity (Python) diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs index 134cac5f11..85cab21b26 100644 --- a/com.unity.ml-agents/Runtime/Academy.cs +++ b/com.unity.ml-agents/Runtime/Academy.cs @@ -183,6 +183,11 @@ public int InferenceSeed set { m_InferenceSeed = value; } } + // Number of training areas to instantiate + int m_NumAreas; + + public int NumAreas => m_NumAreas; + /// /// Returns the RLCapabilities of the python client that the unity process is connected to. /// @@ -451,6 +456,7 @@ out var unityRlInitParameters UnityEngine.Random.InitState(unityRlInitParameters.seed); // We might have inference-only Agents, so set the seed for them too. 
m_InferenceSeed = unityRlInitParameters.seed; + m_NumAreas = unityRlInitParameters.numAreas; TrainerCapabilities = unityRlInitParameters.TrainerCapabilities; TrainerCapabilities.WarnOnPythonMissingBaseRLCapabilities(); } diff --git a/com.unity.ml-agents/Runtime/Areas.meta b/com.unity.ml-agents/Runtime/Areas.meta new file mode 100644 index 0000000000..d00b0cf67c --- /dev/null +++ b/com.unity.ml-agents/Runtime/Areas.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 4774a04ed09a1405cb957aace235adcb +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs b/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs new file mode 100644 index 0000000000..47d823625e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs @@ -0,0 +1,88 @@ +using System; +using Unity.Mathematics; +using UnityEngine; + +namespace Unity.MLAgents.Areas +{ + /// + /// The Training Area Replicator allows for a training area object group to be replicated dynamically during runtime. + /// + public class TrainingAreaReplicator : MonoBehaviour + { + public GameObject baseArea; + public int numAreas = 1; + public float separation = 10f; + + int3 m_GridSize = new int3(1, 1, 1); + int m_areaCount = 0; + string m_TrainingAreaName; + + public int3 GridSize => m_GridSize; + public string TrainingAreaName => m_TrainingAreaName; + + public void Awake() + { + // Computes the Grid Size on Awake + ComputeGridSize(); + // Sets the TrainingArea name to the name of the base area. + m_TrainingAreaName = baseArea.name; + } + + public void OnEnable() + { + // Adds the training area replicas during OnEnable to ensure they are added before the Academy begins its work. + AddEnvironments(); + } + + /// + /// Computes the Grid Size for replicating the training area.
+ /// + void ComputeGridSize() + { + // check if running inference, if so, use the num areas set through the component, + // otherwise, pull it from the academy + if (Academy.Instance.Communicator != null) + numAreas = Academy.Instance.NumAreas; + + var rootNumAreas = Mathf.Pow(numAreas, 1.0f / 3.0f); + m_GridSize.x = Mathf.CeilToInt(rootNumAreas); + m_GridSize.y = Mathf.CeilToInt(rootNumAreas); + var zSize = Mathf.CeilToInt((float)numAreas / (m_GridSize.x * m_GridSize.y)); + m_GridSize.z = zSize == 0 ? 1 : zSize; + } + + /// + /// Adds replicas of the training area to the scene. + /// + /// + void AddEnvironments() + { + if (numAreas > m_GridSize.x * m_GridSize.y * m_GridSize.z) + { + throw new UnityAgentsException("The number of training areas that you have specified exceeds the size of the grid."); + } + + for (int z = 0; z < m_GridSize.z; z++) + { + for (int y = 0; y < m_GridSize.y; y++) + { + for (int x = 0; x < m_GridSize.x; x++) + { + if (m_areaCount == 0) + { + // Skip this first area since it already exists. 
+ m_areaCount = 1; + } + else if (m_areaCount < numAreas) + { + m_areaCount++; + var area = Instantiate(baseArea, new Vector3(x * separation, y * separation, z * separation), Quaternion.identity); + area.name = m_TrainingAreaName; + } + } + } + } + } + } +} + diff --git a/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs.meta b/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs.meta new file mode 100644 index 0000000000..84ac36d789 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Areas/TrainingAreaReplicator.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7fc26c3bda6fe4937b2264ffe43190b7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index e49bd876e7..e5a97cd167 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -298,6 +298,7 @@ public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitiali return new UnityRLInitParameters { seed = inputProto.Seed, + numAreas = inputProto.NumAreas, pythonLibraryVersion = inputProto.PackageVersion, pythonCommunicationVersion = inputProto.CommunicationVersion, TrainerCapabilities = inputProto.Capabilities.ToRLCapabilities() diff --git a/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs b/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs index 092c4af84d..035de2d3c8 100644 --- a/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs +++ b/com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs @@ -39,6 +39,11 @@ internal struct UnityRLInitParameters /// public int seed; + /// + /// The number of areas to replicate if Training Area Replication is used in the scene. 
+ /// + public int numAreas; + /// /// The library version of the python process. /// diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs index f1d92e4a95..1b83a17e68 100644 --- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/UnityRlInitializationInput.cs @@ -27,16 +27,16 @@ static UnityRlInitializationInputReflection() { "CkZtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3Js", "X2luaXRpYWxpemF0aW9uX2lucHV0LnByb3RvEhRjb21tdW5pY2F0b3Jfb2Jq", "ZWN0cxo1bWxhZ2VudHNfZW52cy9jb21tdW5pY2F0b3Jfb2JqZWN0cy9jYXBh", - "YmlsaXRpZXMucHJvdG8irQEKH1VuaXR5UkxJbml0aWFsaXphdGlvbklucHV0", + "YmlsaXRpZXMucHJvdG8iwAEKH1VuaXR5UkxJbml0aWFsaXphdGlvbklucHV0", "UHJvdG8SDAoEc2VlZBgBIAEoBRIdChVjb21tdW5pY2F0aW9uX3ZlcnNpb24Y", "AiABKAkSFwoPcGFja2FnZV92ZXJzaW9uGAMgASgJEkQKDGNhcGFiaWxpdGll", "cxgEIAEoCzIuLmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxDYXBhYmls", - "aXRpZXNQcm90b0IlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2Jq", - "ZWN0c2IGcHJvdG8z")); + "aXRpZXNQcm90bxIRCgludW1fYXJlYXMYBSABKAVCJaoCIlVuaXR5Lk1MQWdl", + "bnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.CapabilitiesReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto.Parser, new[]{ "Seed", "CommunicationVersion", "PackageVersion", "Capabilities" }, null, null, null) + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto), 
global::Unity.MLAgents.CommunicatorObjects.UnityRLInitializationInputProto.Parser, new[]{ "Seed", "CommunicationVersion", "PackageVersion", "Capabilities", "NumAreas" }, null, null, null) })); } #endregion @@ -75,6 +75,7 @@ public UnityRLInitializationInputProto(UnityRLInitializationInputProto other) : communicationVersion_ = other.communicationVersion_; packageVersion_ = other.packageVersion_; Capabilities = other.capabilities_ != null ? other.Capabilities.Clone() : null; + numAreas_ = other.numAreas_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -136,6 +137,20 @@ public string PackageVersion { } } + /// Field number for the "num_areas" field. + public const int NumAreasFieldNumber = 5; + private int numAreas_; + /// + /// The number of training areas to instantiate + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumAreas { + get { return numAreas_; } + set { + numAreas_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as UnityRLInitializationInputProto); @@ -153,6 +168,7 @@ public bool Equals(UnityRLInitializationInputProto other) { if (CommunicationVersion != other.CommunicationVersion) return false; if (PackageVersion != other.PackageVersion) return false; if (!object.Equals(Capabilities, other.Capabilities)) return false; + if (NumAreas != other.NumAreas) return false; return Equals(_unknownFields, other._unknownFields); } @@ -163,6 +179,7 @@ public override int GetHashCode() { if (CommunicationVersion.Length != 0) hash ^= CommunicationVersion.GetHashCode(); if (PackageVersion.Length != 0) hash ^= PackageVersion.GetHashCode(); if (capabilities_ != null) hash ^= Capabilities.GetHashCode(); + if (NumAreas != 0) hash ^= NumAreas.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -192,6 +209,10 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(34); 
output.WriteMessage(Capabilities); } + if (NumAreas != 0) { + output.WriteRawTag(40); + output.WriteInt32(NumAreas); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -212,6 +233,9 @@ public int CalculateSize() { if (capabilities_ != null) { size += 1 + pb::CodedOutputStream.ComputeMessageSize(Capabilities); } + if (NumAreas != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumAreas); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -238,6 +262,9 @@ public void MergeFrom(UnityRLInitializationInputProto other) { } Capabilities.MergeFrom(other.Capabilities); } + if (other.NumAreas != 0) { + NumAreas = other.NumAreas; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -268,6 +295,10 @@ public void MergeFrom(pb::CodedInputStream input) { input.ReadMessage(capabilities_); break; } + case 40: { + NumAreas = input.ReadInt32(); + break; + } } } } diff --git a/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef b/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef index 8e0fbb8e69..4a54e71fee 100755 --- a/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef +++ b/com.unity.ml-agents/Runtime/Unity.ML-Agents.asmdef @@ -2,7 +2,8 @@ "name": "Unity.ML-Agents", "references": [ "Unity.Barracuda", - "Unity.ML-Agents.CommunicatorObjects" + "Unity.ML-Agents.CommunicatorObjects", + "Unity.Mathematics" ], "optionalUnityReferences": [], "includePlatforms": [], diff --git a/com.unity.ml-agents/Tests/Editor/Areas.meta b/com.unity.ml-agents/Tests/Editor/Areas.meta new file mode 100644 index 0000000000..42901a0e6b --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Areas.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: d32a102dc1f004c33b05a30190a9d039 +timeCreated: 1632841906 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs b/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs new file mode 100644 index 
0000000000..f1046ebf68 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs @@ -0,0 +1,61 @@ +using System.Linq; +using NUnit.Framework; +using Unity.Mathematics; +using Unity.MLAgents.Areas; +using UnityEngine; + +namespace Unity.MLAgents.Tests.Areas +{ + [TestFixture] + public class TrainingAreaReplicatorTests + { + private TrainingAreaReplicator m_Replicator; + + [SetUp] + public void Setup() + { + var gameObject = new GameObject(); + var trainingArea = new GameObject(); + trainingArea.name = "MyTrainingArea"; + m_Replicator = gameObject.AddComponent(); + m_Replicator.baseArea = trainingArea; + } + + private static object[] NumAreasCases = + { + new object[] {1}, + new object[] {2}, + new object[] {5}, + new object[] {7}, + new object[] {8}, + new object[] {64}, + new object[] {63}, + }; + + [TestCaseSource(nameof(NumAreasCases))] + public void TestComputeGridSize(int numAreas) + { + m_Replicator.numAreas = numAreas; + m_Replicator.Awake(); + m_Replicator.OnEnable(); + var m_CorrectGridSize = int3.zero; + var m_RootNumAreas = Mathf.Pow(numAreas, 1.0f / 3.0f); + m_CorrectGridSize.x = Mathf.CeilToInt(m_RootNumAreas); + m_CorrectGridSize.y = Mathf.CeilToInt(m_RootNumAreas); + m_CorrectGridSize.z = Mathf.CeilToInt((float)numAreas / (m_CorrectGridSize.x * m_CorrectGridSize.y)); + Assert.GreaterOrEqual(m_Replicator.GridSize.x * m_Replicator.GridSize.y * m_Replicator.GridSize.z, m_Replicator.numAreas); + Assert.AreEqual(m_CorrectGridSize, m_Replicator.GridSize); + } + + [Test] + public void TestAddEnvironments() + { + m_Replicator.numAreas = 10; + m_Replicator.Awake(); + m_Replicator.OnEnable(); + var trainingAreas = Resources.FindObjectsOfTypeAll().Where(obj => obj.name == m_Replicator.TrainingAreaName); + Assert.AreEqual(10, trainingAreas.Count()); + + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs.meta new 
file mode 100644 index 0000000000..4ebc4ba4d1 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Areas/TrainingAreaReplicatorTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 391a03d82068e44b5bba0ca55215b0c7 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef b/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef index 59135591a4..128105b400 100755 --- a/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef +++ b/com.unity.ml-agents/Tests/Editor/Unity.ML-Agents.Editor.Tests.asmdef @@ -4,6 +4,7 @@ "Unity.ML-Agents.Editor", "Unity.ML-Agents", "Unity.Barracuda", + "Unity.Mathematics", "Unity.ML-Agents.CommunicatorObjects", "Unity.ML-Agents.Runtime.Utils.Tests", "Unity.ML-Agents.Runtime.Sensor.Tests" diff --git a/docs/Learning-Environment-Create-New.md b/docs/Learning-Environment-Create-New.md index 867e889080..3e2e374cf8 100644 --- a/docs/Learning-Environment-Create-New.md +++ b/docs/Learning-Environment-Create-New.md @@ -475,6 +475,14 @@ RollerBall environment: 1. You can now instantiate copies of the TrainingArea prefab. Drag them into your scene, positioning them so that they do not overlap. +Alternatively, you can use the `TrainingAreaReplicator` to replicate training areas. Use the following steps: + +1. Create a new empty Game Object in the scene. +2. Click on the new object and add a TrainingAreaReplicator component to the empty Game Object through the inspector. +3. Drag the training area to `Base Area` in the Training Area Replicator. +4. Specify the number of areas to replicate and the separation between areas. +5. Hit play and the areas will be replicated automatically! 
+ ## Optional: Training Using Concurrent Unity Instances Another level of parallelization comes by training using [concurrent Unity instances](ML-Agents-Overview.md#additional-features). diff --git a/docs/Python-API-Documentation.md b/docs/Python-API-Documentation.md index 73757cf160..54012acacc 100644 --- a/docs/Python-API-Documentation.md +++ b/docs/Python-API-Documentation.md @@ -623,7 +623,7 @@ class UnityEnvironment(BaseEnv) #### \_\_init\_\_ ```python - | __init__(file_name: Optional[str] = None, worker_id: int = 0, base_port: Optional[int] = None, seed: int = 0, no_graphics: bool = False, timeout_wait: int = 60, additional_args: Optional[List[str]] = None, side_channels: Optional[List[SideChannel]] = None, log_folder: Optional[str] = None) + | __init__(file_name: Optional[str] = None, worker_id: int = 0, base_port: Optional[int] = None, seed: int = 0, no_graphics: bool = False, timeout_wait: int = 60, additional_args: Optional[List[str]] = None, side_channels: Optional[List[SideChannel]] = None, log_folder: Optional[str] = None, num_areas: int = 1) ``` Starts a new unity environment and establishes a connection with the environment. 
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.py index 7f5f4875f5..d111397ada 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.py @@ -20,7 +20,7 @@ name='mlagents_envs/communicator_objects/unity_rl_initialization_input.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\nFmlagents_envs/communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents_envs/communicator_objects/capabilities.proto\"\xad\x01\n\x1fUnityRLInitializationInputProto\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x12\x1d\n\x15\x63ommunication_version\x18\x02 \x01(\t\x12\x17\n\x0fpackage_version\x18\x03 \x01(\t\x12\x44\n\x0c\x63\x61pabilities\x18\x04 \x01(\x0b\x32..communicator_objects.UnityRLCapabilitiesProtoB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\nFmlagents_envs/communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents_envs/communicator_objects/capabilities.proto\"\xc0\x01\n\x1fUnityRLInitializationInputProto\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x12\x1d\n\x15\x63ommunication_version\x18\x02 \x01(\t\x12\x17\n\x0fpackage_version\x18\x03 \x01(\t\x12\x44\n\x0c\x63\x61pabilities\x18\x04 \x01(\x0b\x32..communicator_objects.UnityRLCapabilitiesProto\x12\x11\n\tnum_areas\x18\x05 \x01(\x05\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') , dependencies=[mlagents__envs_dot_communicator__objects_dot_capabilities__pb2.DESCRIPTOR,]) @@ -62,6 +62,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='num_areas', 
full_name='communicator_objects.UnityRLInitializationInputProto.num_areas', index=4, + number=5, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -75,7 +82,7 @@ oneofs=[ ], serialized_start=152, - serialized_end=325, + serialized_end=344, ) _UNITYRLINITIALIZATIONINPUTPROTO.fields_by_name['capabilities'].message_type = mlagents__envs_dot_communicator__objects_dot_capabilities__pb2._UNITYRLCAPABILITIESPROTO diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.pyi index d2ff71da29..f502997610 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.pyi +++ b/ml-agents-envs/mlagents_envs/communicator_objects/unity_rl_initialization_input_pb2.pyi @@ -33,6 +33,7 @@ class UnityRLInitializationInputProto(google___protobuf___message___Message): seed = ... # type: builtin___int communication_version = ... # type: typing___Text package_version = ... # type: typing___Text + num_areas = ... # type: builtin___int @property def capabilities(self) -> mlagents_envs___communicator_objects___capabilities_pb2___UnityRLCapabilitiesProto: ... @@ -43,6 +44,7 @@ class UnityRLInitializationInputProto(google___protobuf___message___Message): communication_version : typing___Optional[typing___Text] = None, package_version : typing___Optional[typing___Text] = None, capabilities : typing___Optional[mlagents_envs___communicator_objects___capabilities_pb2___UnityRLCapabilitiesProto] = None, + num_areas : typing___Optional[builtin___int] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> UnityRLInitializationInputProto: ... 
@@ -50,7 +52,7 @@ class UnityRLInitializationInputProto(google___protobuf___message___Message): def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): def HasField(self, field_name: typing_extensions___Literal[u"capabilities"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"capabilities",u"communication_version",u"package_version",u"seed"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"capabilities",u"communication_version",u"num_areas",u"package_version",u"seed"]) -> None: ... else: def HasField(self, field_name: typing_extensions___Literal[u"capabilities",b"capabilities"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"capabilities",b"capabilities",u"communication_version",b"communication_version",u"package_version",b"package_version",u"seed",b"seed"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"capabilities",b"capabilities",u"communication_version",b"communication_version",u"num_areas",b"num_areas",u"package_version",b"package_version",u"seed",b"seed"]) -> None: ... diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 776c5d1030..18731a20bb 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -153,6 +153,7 @@ def __init__( additional_args: Optional[List[str]] = None, side_channels: Optional[List[SideChannel]] = None, log_folder: Optional[str] = None, + num_areas: int = 1, ): """ Starts a new unity environment and establishes a connection with the environment. 
@@ -229,6 +230,7 @@ def __init__( communication_version=self.API_VERSION, package_version=mlagents_envs.__version__, capabilities=UnityEnvironment._get_capabilities_proto(), + num_areas=num_areas, ) try: aca_output = self._send_academy_parameters(rl_init_parameters_in) diff --git a/ml-agents/mlagents/trainers/cli_utils.py b/ml-agents/mlagents/trainers/cli_utils.py index 5f87ea094c..5884c3a5c5 100644 --- a/ml-agents/mlagents/trainers/cli_utils.py +++ b/ml-agents/mlagents/trainers/cli_utils.py @@ -161,6 +161,15 @@ def _create_parser() -> argparse.ArgumentParser: "from when training", action=DetectDefault, ) + + argparser.add_argument( + "--num-areas", + default=1, + type=int, + help="The number of parallel training areas in each Unity environment instance.", + action=DetectDefault, + ) + argparser.add_argument( "--debug", default=False, diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py index e3015ffddc..3f8fc80632 100644 --- a/ml-agents/mlagents/trainers/learn.py +++ b/ml-agents/mlagents/trainers/learn.py @@ -52,10 +52,11 @@ def parse_command_line(argv: Optional[List[str]] = None) -> RunOptions: return RunOptions.from_argparse(args) -def run_training(run_seed: int, options: RunOptions) -> None: +def run_training(run_seed: int, options: RunOptions, num_areas: int) -> None: """ Launches training session. :param run_seed: Random seed used for training. 
+ :param num_areas: Number of training areas to instantiate :param options: parsed command line arguments """ with hierarchical_timer("run_training.setup"): @@ -95,6 +96,7 @@ def run_training(run_seed: int, options: RunOptions) -> None: env_settings.env_path, engine_settings.no_graphics, run_seed, + num_areas, port, env_settings.env_args, os.path.abspath(run_logs_dir), # Unity environment requires absolute path @@ -168,6 +170,7 @@ def create_environment_factory( env_path: Optional[str], no_graphics: bool, seed: int, + num_areas: int, start_port: Optional[int], env_args: Optional[List[str]], log_folder: str, @@ -181,6 +184,7 @@ def create_unity_environment( file_name=env_path, worker_id=worker_id, seed=env_seed, + num_areas=num_areas, no_graphics=no_graphics, base_port=start_port, additional_args=env_args, @@ -237,6 +241,7 @@ def run_cli(options: RunOptions) -> None: ) run_seed = options.env_settings.seed + num_areas = options.env_settings.num_areas # Add some timer metadata add_timer_metadata("mlagents_version", mlagents.trainers.__version__) @@ -248,7 +253,7 @@ def run_cli(options: RunOptions) -> None: if options.env_settings.seed == -1: run_seed = np.random.randint(0, 10000) logger.debug(f"run_seed set to {run_seed}") - run_training(run_seed, options) + run_training(run_seed, options, num_areas) def main(): diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py index fbdf317636..fe52fb838c 100644 --- a/ml-agents/mlagents/trainers/settings.py +++ b/ml-agents/mlagents/trainers/settings.py @@ -822,6 +822,7 @@ class EnvironmentSettings: env_args: Optional[List[str]] = parser.get_default("env_args") base_port: int = parser.get_default("base_port") num_envs: int = attr.ib(default=parser.get_default("num_envs")) + num_areas: int = attr.ib(default=parser.get_default("num_areas")) seed: int = parser.get_default("seed") max_lifetime_restarts: int = parser.get_default("max_lifetime_restarts") restarts_rate_limit_n: int = 
parser.get_default("restarts_rate_limit_n") @@ -834,6 +835,11 @@ def validate_num_envs(self, attribute, value): if value > 1 and self.env_path is None: raise ValueError("num_envs must be 1 if env_path is not set.") + @num_areas.validator + def validate_num_area(self, attribute, value): + if value <= 0: + raise ValueError("num_areas must be set to a positive number >= 1.") + @attr.s(auto_attribs=True) class EngineSettings: diff --git a/ml-agents/mlagents/trainers/tests/test_learn.py b/ml-agents/mlagents/trainers/tests/test_learn.py index 898824c2a0..7837d59eb2 100644 --- a/ml-agents/mlagents/trainers/tests/test_learn.py +++ b/ml-agents/mlagents/trainers/tests/test_learn.py @@ -37,6 +37,7 @@ def basic_options(extra_args=None): env_settings: env_path: "./oldenvfile" num_envs: 4 + num_areas: 4 base_port: 4001 seed: 9870 checkpoint_settings: @@ -73,7 +74,7 @@ def test_run_training( with patch.object(TrainerController, "__init__", mock_init): with patch.object(TrainerController, "start_learning", MagicMock()): options = basic_options() - learn.run_training(0, options) + learn.run_training(0, options, 1) mock_init.assert_called_once_with( trainer_factory_mock.return_value, os.path.join("results", "ppo"), @@ -103,6 +104,7 @@ def test_bad_env_path(): env_path="/foo/bar", no_graphics=True, seed=-1, + num_areas=1, start_port=8000, env_args=None, log_folder="results/log_folder", @@ -126,6 +128,7 @@ def test_commandline_args(mock_file): assert opt.env_settings.seed == -1 assert opt.env_settings.base_port == 5005 assert opt.env_settings.num_envs == 1 + assert opt.env_settings.num_areas == 1 assert opt.engine_settings.no_graphics is False assert opt.debug is False assert opt.env_settings.env_args is None @@ -140,6 +143,7 @@ def test_commandline_args(mock_file): "--base-port=4004", "--initialize-from=testdir", "--num-envs=2", + "--num-areas=2", "--no-graphics", "--debug", ] @@ -152,6 +156,7 @@ def test_commandline_args(mock_file): assert opt.env_settings.seed == 7890 assert 
opt.env_settings.base_port == 4004 assert opt.env_settings.num_envs == 2 + assert opt.env_settings.num_areas == 2 assert opt.engine_settings.no_graphics is True assert opt.debug is True assert opt.checkpoint_settings.inference is True @@ -176,6 +181,7 @@ def test_yaml_args(mock_file): assert opt.env_settings.seed == 9870 assert opt.env_settings.base_port == 4001 assert opt.env_settings.num_envs == 4 + assert opt.env_settings.num_areas == 4 assert opt.engine_settings.no_graphics is False assert opt.debug is False assert opt.env_settings.env_args is None @@ -190,6 +196,7 @@ def test_yaml_args(mock_file): "--train", "--base-port=4004", "--num-envs=2", + "--num-areas=2", "--no-graphics", "--debug", "--results-dir=myresults", @@ -202,6 +209,7 @@ def test_yaml_args(mock_file): assert opt.env_settings.seed == 7890 assert opt.env_settings.base_port == 4004 assert opt.env_settings.num_envs == 2 + assert opt.env_settings.num_areas == 2 assert opt.engine_settings.no_graphics is True assert opt.debug is True assert opt.checkpoint_settings.inference is True diff --git a/ml-agents/mlagents/trainers/tests/test_settings.py b/ml-agents/mlagents/trainers/tests/test_settings.py index 340649b901..5fe453e43b 100644 --- a/ml-agents/mlagents/trainers/tests/test_settings.py +++ b/ml-agents/mlagents/trainers/tests/test_settings.py @@ -416,6 +416,7 @@ def test_exportable_settings(use_defaults): - test_env_args2 base_port: 12345 num_envs: 8 + num_areas: 8 seed: 12345 engine_settings: width: 12345 @@ -514,6 +515,9 @@ def test_environment_settings(): # 1 env is OK if no env_path EnvironmentSettings(num_envs=1) + # 2 areas are OK + EnvironmentSettings(num_areas=2) + # multiple envs is OK if env_path is set EnvironmentSettings(num_envs=42, env_path="/foo/bar.exe") diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/unity_rl_initialization_input.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/unity_rl_initialization_input.proto index 
6db949d0c2..458e1d75d4 100644 --- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/unity_rl_initialization_input.proto +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/unity_rl_initialization_input.proto @@ -16,4 +16,7 @@ message UnityRLInitializationInputProto { // The RL Capabilities of the Python trainer. UnityRLCapabilitiesProto capabilities = 4; + + // The number of training areas to instantiate + int32 num_areas = 5; }