Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kokkos ensure kokkos function #16

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ autoconf/autom4te.cache
/CMakeSettings.json
# CLion project configuration
/.idea
/.cache

#==============================================================================#
# Directories to ignore (do not add trailing '/'s, they skip symlinks).
Expand Down
1 change: 1 addition & 0 deletions clang-tools-extra/clang-tidy/kokkos/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(LLVM_LINK_COMPONENTS support)

add_clang_library(clangTidyKokkosModule
EnsureKokkosFunctionCheck.cpp
ImplicitThisCaptureCheck.cpp
KokkosMatchers.cpp
KokkosTidyModule.cpp
Expand Down
220 changes: 220 additions & 0 deletions clang-tools-extra/clang-tidy/kokkos/EnsureKokkosFunctionCheck.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
//===--- EnsureKokkosFunctionCheck.cpp - clang-tidy -----------------------===//
//
// Copyright 2020 National Technology & Engineering Solutions of Sandia,
// LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "EnsureKokkosFunctionCheck.h"
#include "KokkosMatchers.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"

using namespace clang::ast_matchers;

namespace clang {
namespace tidy {
namespace kokkos {
namespace {

std::string KF_Regex = "KOKKOS_.*FUNCTION"; // NOLINT

auto notKFunc(std::string const &AllowedFuncRegex) {
auto AllowedFuncMatch = unless(matchesName(AllowedFuncRegex));
return functionDecl(unless(matchesAttr(KF_Regex)),
unless(isExpansionInSystemHeader()), AllowedFuncMatch);
}

bool isAnnotated(CXXMethodDecl const *Method) {
// If the method is annotated the match will not be empty
return !match(cxxMethodDecl(matchesAttr(KF_Regex)), *Method,
Method->getASTContext())
.empty();
}

// TODO one day we might want to check if the lambda is local to our current
// function context, but until someone complains that's a lot of work. The
// other case we aren't going to deal with is: void foo(){ struct S { static
// void func(){} }; S::func(); }
bool callExprIsToLambaOp(CallExpr const *CE) {
if (auto const *CMD =
dyn_cast_or_null<CXXMethodDecl>(CE->getDirectCallee())) {
if (auto const *Parent = CMD->getParent()) {
if (Parent->isLambda()) {
return true;
}
}
}
return false;
}

auto checkLambdaBody(CXXRecordDecl const *Lambda,
std::string const &AllowedFuncRegex) {
assert(Lambda->isLambda());
llvm::SmallPtrSet<CallExpr const *, 1> BadCallSet;
auto const *FD = Lambda->getLambdaCallOperator();
if (!FD) {
return BadCallSet;
}

auto notKCalls = // NOLINT
callExpr(callee(notKFunc(AllowedFuncRegex))).bind("CE");

auto BadCalls = match(functionDecl(forEachDescendant(notKCalls)), *FD,
FD->getASTContext());

for (auto BadCall : BadCalls) {
auto const *CE = BadCall.getNodeAs<CallExpr>("CE");
if (callExprIsToLambaOp(CE)) { // function call handles nullptr
continue;
}

BadCallSet.insert(CE);
}

return BadCallSet;
}

// Recurses through the tree of all calls to functions with visble bodies
void recurseCallExpr(
llvm::SmallPtrSet<CXXMethodDecl const *, 8> const &FunctorMethods,
CallExpr const *Call,
llvm::SmallPtrSet<CXXMethodDecl const *, 4> &Results) {

// Get the body of the called function
auto const *CallDecl = Call->getCalleeDecl();
if (CallDecl == nullptr || !CallDecl->hasBody()) {
return;
}

auto &ASTContext = CallDecl->getASTContext();

// Check if the called function is a member function of the functor
// if yes then write the result back out.
if (auto const *Method = dyn_cast<CXXMethodDecl>(CallDecl)) {
if (FunctorMethods.count(Method) > 0) {
Results.insert(Method);
}
}

// Match all callexprs in our body
auto CEs = match(compoundStmt(forEachDescendant(callExpr().bind("CE"))),
*(CallDecl->getBody()), ASTContext);

// Check all those calls for uses of members of the functor as well
for (auto BN : CEs) {
if (auto const *CE = BN.getNodeAs<CallExpr>("CE")) {
recurseCallExpr(FunctorMethods, CE, Results);
}
}
}

// Find methods from our functor called in the tree of Kokkos::parallel_x
auto checkFunctorBody(CXXRecordDecl const *Functor, CallExpr const *CallSite) {
llvm::SmallPtrSet<CXXMethodDecl const *, 8> FunctorMethods;
for (auto const *Method : Functor->methods()) {
FunctorMethods.insert(Method);
}
llvm::SmallPtrSet<CXXMethodDecl const *, 4> Results;
recurseCallExpr(FunctorMethods, CallSite, Results);

return Results;
}

} // namespace

EnsureKokkosFunctionCheck::EnsureKokkosFunctionCheck(StringRef Name,
ClangTidyContext *Context)
: ClangTidyCheck(Name, Context) {
AllowIfExplicitHost = std::stoi(Options.get("AllowIfExplicitHost", "0"));
AllowedFunctionsRegex = Options.get("AllowedFunctionsRegex", "a^");
// This can't be empty because the regex ast matchers assert !empty
assert(!AllowedFunctionsRegex.empty());
}

void EnsureKokkosFunctionCheck::storeOptions(
ClangTidyOptions::OptionMap &Opts) {
Options.store(Opts, "AllowedFunctionsRegex", AllowedFunctionsRegex);
Options.store(Opts, "AllowIfExplicitHost",
std::to_string(AllowIfExplicitHost));
}

void EnsureKokkosFunctionCheck::registerMatchers(MatchFinder *Finder) {
auto notKCalls = // NOLINT
callExpr(callee(notKFunc(AllowedFunctionsRegex))).bind("CE");

// We have to be sure that we don't match functionDecls in systems headers,
// because they might call our Functor, which if it is a lambda will not be
// marked with KOKKOS_FUNCITON
calewis marked this conversation as resolved.
Show resolved Hide resolved
Finder->addMatcher(functionDecl(matchesAttr(KF_Regex),
unless(isExpansionInSystemHeader()),
forEachDescendant(notKCalls))
.bind("ParentFD"),
this);

// Need to check the Functor also
auto Functor = expr(hasType(
cxxRecordDecl(unless(isExpansionInSystemHeader())).bind("Functor")));
Finder->addMatcher(callExpr(isKokkosParallelCall(), hasAnyArgument(Functor))
.bind("KokkosCE"),
this);
}

void EnsureKokkosFunctionCheck::check(const MatchFinder::MatchResult &Result) {

auto const *ParentFD = Result.Nodes.getNodeAs<FunctionDecl>("ParentFD");
auto const *CE = Result.Nodes.getNodeAs<CallExpr>("CE");
auto const *Functor = Result.Nodes.getNodeAs<CXXRecordDecl>("Functor");

if (ParentFD != nullptr) {
if (callExprIsToLambaOp(CE)) { // Avoid false positives for local lambdas
return;
}

diag(CE->getBeginLoc(),
"function %0 called in %1 is missing a KOKKOS_X_FUNCTION annotation")
<< CE->getDirectCallee() << ParentFD;
diag(CE->getDirectCallee()->getLocation(), "Function %0 declared here",
DiagnosticIDs::Note)
<< CE->getDirectCallee();
}

if (Functor != nullptr) {
auto const *CE = Result.Nodes.getNodeAs<CallExpr>("KokkosCE");
if (AllowIfExplicitHost != 0 && explicitlyDefaultHostExecutionSpace(CE)) {
return;
}

if (Functor->isLambda()) {
auto BadCalls = checkLambdaBody(Functor, AllowedFunctionsRegex);
for (auto const *BadCall : BadCalls) {
diag(BadCall->getBeginLoc(),
"Function %0 called in a lambda was missing "
"KOKKOS_X_FUNCTION annotation.")
<< BadCall->getDirectCallee();
diag(BadCall->getDirectCallee()->getBeginLoc(),
"Function %0 was delcared here", DiagnosticIDs::Note)
<< BadCall->getDirectCallee();
}
} else {
for (auto const *CalledMethod : checkFunctorBody(Functor, CE)) {
if (isAnnotated(CalledMethod)) {
continue;
}

diag(CalledMethod->getBeginLoc(), "Member Function of %0, requires a "
"KOKKOS_X_FUNCTION annotation.")
<< CalledMethod->getParent();
}
}
}
}

} // namespace kokkos
} // namespace tidy
} // namespace clang
39 changes: 39 additions & 0 deletions clang-tools-extra/clang-tidy/kokkos/EnsureKokkosFunctionCheck.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===--- EnsureKokkosFunctionCheck.h - clang-tidy ---------------*- C++ -*-===//
//
// Copyright 2020 National Technology & Engineering Solutions of Sandia,
// LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_KOKKOS_ENSUREKOKKOSFUNCTIONCHECK_H
#define LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_KOKKOS_ENSUREKOKKOSFUNCTIONCHECK_H

#include "../ClangTidyCheck.h"

namespace clang {
namespace tidy {
namespace kokkos {

/// Check that ensures user provided functions were properly annotated
class EnsureKokkosFunctionCheck : public ClangTidyCheck {
public:
EnsureKokkosFunctionCheck(StringRef Name, ClangTidyContext *Context);
void registerMatchers(ast_matchers::MatchFinder *Finder) override;
void check(const ast_matchers::MatchFinder::MatchResult &Result) override;
void storeOptions(ClangTidyOptions::OptionMap &Opts) override;

private:
std::string AllowedFunctionsRegex;
int AllowIfExplicitHost;
};

} // namespace kokkos
} // namespace tidy
} // namespace clang

#endif // LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_KOKKOS_ENSUREKOKKOSFUNCTIONCHECK_H
13 changes: 5 additions & 8 deletions clang-tools-extra/clang-tidy/kokkos/ImplicitThisCaptureCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,12 @@ llvm::Optional<SourceLocation> capturesThis(CXXRecordDecl const *CRD) {
ImplicitThisCaptureCheck::ImplicitThisCaptureCheck(StringRef Name,
ClangTidyContext *Context)
: ClangTidyCheck(Name, Context) {
CheckIfExplicitHost = std::stoi(Options.get("CheckIfExplicitHost", "0"));
HostTypeDefRegex =
Options.get("HostTypeDefRegex", "Kokkos::DefaultHostExecutionSpace");
AllowIfExplicitHost = std::stoi(Options.get("AllowIfExplicitHost", "0"));
}

void ImplicitThisCaptureCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
Options.store(Opts, "CheckIfExplicitHost",
std::to_string(CheckIfExplicitHost));
Options.store(Opts, "HostTypeDefRegex", HostTypeDefRegex);
Options.store(Opts, "AllowIfExplicitHost",
std::to_string(AllowIfExplicitHost));
}

void ImplicitThisCaptureCheck::registerMatchers(MatchFinder *Finder) {
Expand All @@ -69,8 +66,8 @@ void ImplicitThisCaptureCheck::registerMatchers(MatchFinder *Finder) {
void ImplicitThisCaptureCheck::check(const MatchFinder::MatchResult &Result) {
auto const *CE = Result.Nodes.getNodeAs<CallExpr>("x");

if (CheckIfExplicitHost) {
if (explicitlyUsingHostExecutionSpace(CE, HostTypeDefRegex)) {
if (AllowIfExplicitHost != 0) {
if (explicitlyDefaultHostExecutionSpace(CE)) {
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ class ImplicitThisCaptureCheck : public ClangTidyCheck {
void storeOptions(ClangTidyOptions::OptionMap &Opts) override;
void check(const ast_matchers::MatchFinder::MatchResult &Result) override;
private:
int CheckIfExplicitHost;
std::string HostTypeDefRegex;
int AllowIfExplicitHost;
};

} // namespace kokkos
Expand Down
55 changes: 44 additions & 11 deletions clang-tools-extra/clang-tidy/kokkos/KokkosMatchers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,63 @@ namespace clang {
namespace tidy {
namespace kokkos {

bool explicitlyUsingHostExecutionSpace(CallExpr const *CE,
std::string const &RegexString) {
namespace {
TypedefNameDecl const *getTypedefFromFirstTemplateArg(Expr const *E) {
if (E == nullptr) {
return nullptr;
}

auto const *TST = E->getType()->getAs<TemplateSpecializationType>();
if (TST == nullptr) {
return nullptr;
}
if (TST->getNumArgs() < 1) {
return nullptr;
}

auto const *TDT = TST->getArg(0).getAsType()->getAs<TypedefType>();
if (TDT == nullptr) {
return nullptr;
}

auto const *TDD = dyn_cast_or_null<TypedefNameDecl>(TDT->getDecl());
return TDD;
}

bool isMatchingAnnotation(Attr const *At, std::string const &target) {
if (auto const *Anna = dyn_cast<AnnotateAttr>(At)) {
if (Anna->getAnnotation() == target) {
return true;
}
}

return false;
}
} // namespace

bool explicitlyDefaultHostExecutionSpace(CallExpr const *CE) {
using namespace clang::ast_matchers;
auto &Ctx = CE->getCalleeDecl()->getASTContext();

// We will assume that any policy where the user might explicitly ask for the
// host space inherits from Impl::PolicyTraits
auto FilterArgs =
hasAnyArgument(expr(hasType(cxxRecordDecl(isDerivedFrom(cxxRecordDecl(
matchesName("Impl::PolicyTraits"))))))
.bind("expr"));
auto FilterArgs = hasAnyArgument(
expr(hasType(classTemplateSpecializationDecl(isDerivedFrom(
cxxRecordDecl(matchesName("Impl::PolicyTraits"))))))
.bind("expr"));

// We have to jump through some hoops to find this, if we just looked at the
// template type of the Policy constructor we lose the sugar and instead of
// Kokkos::DefaultHostExecutionSpace we get what the ever the typedef was set
// to such as Kokkos::Serial, preventing us from figuring out if the user
// actually asked for a host space specifically or just happens to have a
// host space as the default space.
llvm::Regex Reg(RegexString);
auto BNs = match(callExpr(FilterArgs).bind("CE"), *CE, Ctx);
auto BNs = match(callExpr(FilterArgs), *CE, Ctx);
for (auto &BN : BNs) {
if (auto const *E = BN.getNodeAs<Expr>("expr")) {
if (auto const *TST = E->getType()->getAs<TemplateSpecializationType>()) {
if (Reg.match(TST->getArg(0).getAsType().getAsString())) {
if (auto const *TDD =
getTypedefFromFirstTemplateArg(BN.getNodeAs<Expr>("expr"))) {
for (auto const *At : TDD->attrs()) {
if (isMatchingAnnotation(At, "DefaultHostExecutionSpace")) {
return true;
}
}
Expand Down
Loading