Skip to content

Commit

Permalink
Update DirectML samples to 1.5.0 redist (microsoft#106)
Browse files Browse the repository at this point in the history
* Update samples to support DML 1.5.0
* Add ARM/ARM64 solution configurations
* Add DML_TARGET_VERSION guards around new APIs in DirectMLX.h
* Use different intermediate folders for HelloDirectML/HelloDirectMLX to allow parallel batch build without them stomping over one another
  • Loading branch information
Adrian Tsai authored Apr 21, 2021
1 parent c545e71 commit 78d28ae
Show file tree
Hide file tree
Showing 17 changed files with 1,520 additions and 99 deletions.
40 changes: 26 additions & 14 deletions Libraries/DirectMLX.h
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,8 @@ namespace dml
return output;
}

#if DML_TARGET_VERSION >= 0x3100

inline Expression ClipGrad(Expression input, Expression inputGradient, float min, float max)
{
detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
Expand All @@ -1084,6 +1086,8 @@ namespace dml
return output;
}

#endif // DML_TARGET_VERSION >= 0x3100

inline Expression Cos(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
{
return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_COS, DML_ELEMENT_WISE_COS_OPERATOR_DESC>(input, scaleBias);
Expand Down Expand Up @@ -1254,11 +1258,15 @@ namespace dml
return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_SQRT, DML_ELEMENT_WISE_SQRT_OPERATOR_DESC>(input, scaleBias);
}

#if DML_TARGET_VERSION >= 0x3100

inline Expression DifferenceSquare(Expression a, Expression b)
{
return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE, DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC>(a, b);
}

#endif // DML_TARGET_VERSION >= 0x3100

inline Expression Subtract(Expression a, Expression b)
{
return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_SUBTRACT, DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC>(a, b);
Expand Down Expand Up @@ -2659,6 +2667,8 @@ namespace dml
return output;
}

#if DML_TARGET_VERSION >= 0x3100

struct BatchNormalizationGradOutputs
{
Expression gradient;
Expand All @@ -2684,29 +2694,31 @@ namespace dml
TensorDesc outputScaleGradientTensor(meanTensor.dataType, meanTensor.sizes, builder->GetTensorPolicy());
TensorDesc outputBiasGradientTensor(meanTensor.dataType, meanTensor.sizes, builder->GetTensorPolicy());

DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC bng_desc = {};
bng_desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
bng_desc.InputGradientTensor = inputGradientTensor.AsPtr<DML_TENSOR_DESC>();
bng_desc.MeanTensor = meanTensor.AsPtr<DML_TENSOR_DESC>();
bng_desc.VarianceTensor = varianceTensor.AsPtr<DML_TENSOR_DESC>();
bng_desc.ScaleTensor = scaleTensor.AsPtr<DML_TENSOR_DESC>();
bng_desc.Epsilon = epsilon;
DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC desc = {};
desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
desc.InputGradientTensor = inputGradientTensor.AsPtr<DML_TENSOR_DESC>();
desc.MeanTensor = meanTensor.AsPtr<DML_TENSOR_DESC>();
desc.VarianceTensor = varianceTensor.AsPtr<DML_TENSOR_DESC>();
desc.ScaleTensor = scaleTensor.AsPtr<DML_TENSOR_DESC>();
desc.Epsilon = epsilon;

bng_desc.OutputGradientTensor = outputGradientTensor.AsPtr<DML_TENSOR_DESC>();
bng_desc.OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr<DML_TENSOR_DESC>();
bng_desc.OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr<DML_TENSOR_DESC>();
desc.OutputGradientTensor = outputGradientTensor.AsPtr<DML_TENSOR_DESC>();
desc.OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr<DML_TENSOR_DESC>();
desc.OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr<DML_TENSOR_DESC>();

dml::detail::NodeOutput* const inputs[] = { input.Impl(), inputGradient.Impl(), mean.Impl(), variance.Impl(), scale.Impl() };
dml::detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION_GRAD, &bng_desc, inputs);
dml::detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION_GRAD, &desc, inputs);

BatchNormalizationGradOutputs outputValues;
outputValues.gradient = builder->CreateNodeOutput(node, 0, *bng_desc.OutputGradientTensor);
outputValues.scaleGradient = builder->CreateNodeOutput(node, 1, *bng_desc.OutputScaleGradientTensor);
outputValues.biasGradient = builder->CreateNodeOutput(node, 2, *bng_desc.OutputBiasGradientTensor);
outputValues.gradient = builder->CreateNodeOutput(node, 0, *desc.OutputGradientTensor);
outputValues.scaleGradient = builder->CreateNodeOutput(node, 1, *desc.OutputScaleGradientTensor);
outputValues.biasGradient = builder->CreateNodeOutput(node, 2, *desc.OutputBiasGradientTensor);

return outputValues;
}

#endif // DML_TARGET_VERSION >= 0x3100

inline Expression MeanVarianceNormalization(
Expression input,
Optional<Expression> scale,
Expand Down
28 changes: 28 additions & 0 deletions Samples/DirectMLSuperResolution/DirectMLSuperResolution.sln
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,60 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "DirectMLXSuperResolution",
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|ARM = Debug|ARM
Debug|ARM64 = Debug|ARM64
Debug|x64 = Debug|x64
Debug|x86 = Debug|x86
Release|ARM = Release|ARM
Release|ARM64 = Release|ARM64
Release|x64 = Release|x64
Release|x86 = Release|x86
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Debug|ARM.ActiveCfg = Debug|ARM
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Debug|ARM.Build.0 = Debug|ARM
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Debug|ARM64.ActiveCfg = Debug|ARM64
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Debug|ARM64.Build.0 = Debug|ARM64
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Debug|x64.ActiveCfg = Debug|x64
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Debug|x64.Build.0 = Debug|x64
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Debug|x86.ActiveCfg = Debug|Win32
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Debug|x86.Build.0 = Debug|Win32
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Release|ARM.ActiveCfg = Release|ARM
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Release|ARM.Build.0 = Release|ARM
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Release|ARM64.ActiveCfg = Release|ARM64
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Release|ARM64.Build.0 = Release|ARM64
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Release|x64.ActiveCfg = Release|x64
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Release|x64.Build.0 = Release|x64
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Release|x86.ActiveCfg = Release|Win32
{70CDBD87-F286-4EE0-87F1-5A1D09396CDA}.Release|x86.Build.0 = Release|Win32
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Debug|ARM.ActiveCfg = Debug|ARM
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Debug|ARM.Build.0 = Debug|ARM
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Debug|ARM64.ActiveCfg = Debug|ARM64
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Debug|ARM64.Build.0 = Debug|ARM64
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Debug|x64.ActiveCfg = Debug|x64
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Debug|x64.Build.0 = Debug|x64
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Debug|x86.ActiveCfg = Debug|Win32
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Debug|x86.Build.0 = Debug|Win32
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Release|ARM.ActiveCfg = Release|ARM
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Release|ARM.Build.0 = Release|ARM
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Release|ARM64.ActiveCfg = Release|ARM64
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Release|ARM64.Build.0 = Release|ARM64
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Release|x64.ActiveCfg = Release|x64
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Release|x64.Build.0 = Release|x64
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Release|x86.ActiveCfg = Release|Win32
{3E0E8608-CD9B-4C76-AF33-29CA38F2C9F0}.Release|x86.Build.0 = Release|Win32
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Debug|ARM.ActiveCfg = Debug|ARM
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Debug|ARM.Build.0 = Debug|ARM
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Debug|ARM64.ActiveCfg = Debug|ARM64
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Debug|ARM64.Build.0 = Debug|ARM64
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Debug|x64.ActiveCfg = Debug|x64
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Debug|x64.Build.0 = Debug|x64
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Debug|x86.ActiveCfg = Debug|Win32
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Debug|x86.Build.0 = Debug|Win32
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Release|ARM.ActiveCfg = Release|ARM
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Release|ARM.Build.0 = Release|ARM
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Release|ARM64.ActiveCfg = Release|ARM64
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Release|ARM64.Build.0 = Release|ARM64
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Release|x64.ActiveCfg = Release|x64
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Release|x64.Build.0 = Release|x64
{31C25314-96AE-4EF0-84B7-0026C14F12AD}.Release|x86.ActiveCfg = Release|Win32
Expand Down
Loading

0 comments on commit 78d28ae

Please sign in to comment.