diff --git a/metallurgy/generate.py b/metallurgy/generate.py index 1c3121b..da80ad3 100644 --- a/metallurgy/generate.py +++ b/metallurgy/generate.py @@ -535,3 +535,63 @@ def quaternary( alloys.append(ternary_alloys) return alloys + + +def perturb(alloy, size=0.05): + if isinstance(alloy, list): + return [perturb(a) for a in alloy] + + composition = dict(alloy.composition) + structure = alloy.structure + constraints = alloy.constraints + if "local_percentages" in constraints: + del constraints["local_percentages"] + if "digits" in constraints: + del constraints["digits"] + + for element in composition: + composition[element] += round( + float(np.random.random(1) * 2 - 1) * size, 2 + ) + composition[element] = max(composition[element], 0) + + elements = list(composition.keys()) + if len(composition) > 1: + for element in elements: + if composition[element] < 0.0001: + if ( + constraints is not None + and element in constraints["percentages"] + and constraints["percentages"][element]["precedence"] > 0 + ): + continue + del composition[element] + + if np.random.random(1) > 0.1: + allowed_elements = list(mg.periodic_table.elements.keys()) + if constraints is not None and "allowed_elements" in constraints: + allowed_elements = constraints["allowed_elements"][:] + for element in composition: + if element in allowed_elements: + allowed_elements.remove(element) + if len(allowed_elements) > 0: + composition[np.random.choice(allowed_elements)] = round( + float(np.random.random(1)) * size, 2 + ) + + elif len(composition) > 1 and np.random.random(1) > 0.1: + to_delete = np.random.choice(list(composition.keys())) + if not ( + constraints is not None + and to_delete in constraints["percentages"] + and constraints["percentages"][to_delete]["precedence"] > 0 + ): + del composition[to_delete] + + if structure is not None and np.random.random(1) > 0.1: + structure = get_random_prototype() + + new_alloy = mg.Alloy( + composition, constraints=constraints, structure=structure + ) + return new_alloy