Skip to content
Snippets Groups Projects
Commit 37d785b4 authored by Justin Bogner's avatar Justin Bogner
Browse files

InstrProf: Use a locally tracked current count in ComputeRegionCounts

No functional change. This just makes it more obvious that the logic
in ComputeRegionCounts only depends on the counter map and local
state.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@236370 91177308-0d34-0410-b5e6-96231b3b80d8
parent 72289675
No related branches found
No related tags found
No related merge requests found
...@@ -242,6 +242,9 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -242,6 +242,9 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
/// next statement, such as at the exit of a loop. /// next statement, such as at the exit of a loop.
bool RecordNextStmtCount; bool RecordNextStmtCount;
/// The count at the current location in the traversal.
uint64_t CurrentCount;
/// The map of statements to count values. /// The map of statements to count values.
llvm::DenseMap<const Stmt *, uint64_t> &CountMap; llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
...@@ -259,14 +262,14 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -259,14 +262,14 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
void RecordStmtCount(const Stmt *S) { void RecordStmtCount(const Stmt *S) {
if (RecordNextStmtCount) { if (RecordNextStmtCount) {
CountMap[S] = PGO.getCurrentRegionCount(); CountMap[S] = CurrentCount;
RecordNextStmtCount = false; RecordNextStmtCount = false;
} }
} }
/// Set and return the current count. /// Set and return the current count.
uint64_t setCount(uint64_t Count) { uint64_t setCount(uint64_t Count) {
PGO.setCurrentRegionCount(Count); CurrentCount = Count;
return Count; return Count;
} }
...@@ -315,7 +318,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -315,7 +318,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
RecordStmtCount(S); RecordStmtCount(S);
if (S->getRetValue()) if (S->getRetValue())
Visit(S->getRetValue()); Visit(S->getRetValue());
PGO.setCurrentRegionUnreachable(); CurrentCount = 0;
RecordNextStmtCount = true; RecordNextStmtCount = true;
} }
...@@ -323,13 +326,13 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -323,13 +326,13 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
RecordStmtCount(E); RecordStmtCount(E);
if (E->getSubExpr()) if (E->getSubExpr())
Visit(E->getSubExpr()); Visit(E->getSubExpr());
PGO.setCurrentRegionUnreachable(); CurrentCount = 0;
RecordNextStmtCount = true; RecordNextStmtCount = true;
} }
void VisitGotoStmt(const GotoStmt *S) { void VisitGotoStmt(const GotoStmt *S) {
RecordStmtCount(S); RecordStmtCount(S);
PGO.setCurrentRegionUnreachable(); CurrentCount = 0;
RecordNextStmtCount = true; RecordNextStmtCount = true;
} }
...@@ -344,30 +347,30 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -344,30 +347,30 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
void VisitBreakStmt(const BreakStmt *S) { void VisitBreakStmt(const BreakStmt *S) {
RecordStmtCount(S); RecordStmtCount(S);
assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); BreakContinueStack.back().BreakCount += CurrentCount;
PGO.setCurrentRegionUnreachable(); CurrentCount = 0;
RecordNextStmtCount = true; RecordNextStmtCount = true;
} }
void VisitContinueStmt(const ContinueStmt *S) { void VisitContinueStmt(const ContinueStmt *S) {
RecordStmtCount(S); RecordStmtCount(S);
assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); BreakContinueStack.back().ContinueCount += CurrentCount;
PGO.setCurrentRegionUnreachable(); CurrentCount = 0;
RecordNextStmtCount = true; RecordNextStmtCount = true;
} }
void VisitWhileStmt(const WhileStmt *S) { void VisitWhileStmt(const WhileStmt *S) {
RecordStmtCount(S); RecordStmtCount(S);
uint64_t ParentCount = PGO.getCurrentRegionCount(); uint64_t ParentCount = CurrentCount;
BreakContinueStack.push_back(BreakContinue()); BreakContinueStack.push_back(BreakContinue());
// Visit the body region first so the break/continue adjustments can be // Visit the body region first so the break/continue adjustments can be
// included when visiting the condition. // included when visiting the condition.
uint64_t BodyCount = setCount(PGO.getRegionCount(S)); uint64_t BodyCount = setCount(PGO.getRegionCount(S));
CountMap[S->getBody()] = PGO.getCurrentRegionCount(); CountMap[S->getBody()] = CurrentCount;
Visit(S->getBody()); Visit(S->getBody());
uint64_t BackedgeCount = PGO.getCurrentRegionCount(); uint64_t BackedgeCount = CurrentCount;
// ...then go back and propagate counts through the condition. The count // ...then go back and propagate counts through the condition. The count
// at the start of the condition is the sum of the incoming edges, // at the start of the condition is the sum of the incoming edges,
...@@ -388,10 +391,10 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -388,10 +391,10 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
BreakContinueStack.push_back(BreakContinue()); BreakContinueStack.push_back(BreakContinue());
// The count doesn't include the fallthrough from the parent scope. Add it. // The count doesn't include the fallthrough from the parent scope. Add it.
uint64_t BodyCount = setCount(LoopCount + PGO.getCurrentRegionCount()); uint64_t BodyCount = setCount(LoopCount + CurrentCount);
CountMap[S->getBody()] = BodyCount; CountMap[S->getBody()] = BodyCount;
Visit(S->getBody()); Visit(S->getBody());
uint64_t BackedgeCount = PGO.getCurrentRegionCount(); uint64_t BackedgeCount = CurrentCount;
BreakContinue BC = BreakContinueStack.pop_back_val(); BreakContinue BC = BreakContinueStack.pop_back_val();
// The count at the start of the condition is equal to the count at the // The count at the start of the condition is equal to the count at the
...@@ -408,7 +411,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -408,7 +411,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
if (S->getInit()) if (S->getInit())
Visit(S->getInit()); Visit(S->getInit());
uint64_t ParentCount = PGO.getCurrentRegionCount(); uint64_t ParentCount = CurrentCount;
BreakContinueStack.push_back(BreakContinue()); BreakContinueStack.push_back(BreakContinue());
// Visit the body region first. (This is basically the same as a while // Visit the body region first. (This is basically the same as a while
...@@ -416,7 +419,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -416,7 +419,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
uint64_t BodyCount = setCount(PGO.getRegionCount(S)); uint64_t BodyCount = setCount(PGO.getRegionCount(S));
CountMap[S->getBody()] = BodyCount; CountMap[S->getBody()] = BodyCount;
Visit(S->getBody()); Visit(S->getBody());
uint64_t BackedgeCount = PGO.getCurrentRegionCount(); uint64_t BackedgeCount = CurrentCount;
BreakContinue BC = BreakContinueStack.pop_back_val(); BreakContinue BC = BreakContinueStack.pop_back_val();
// The increment is essentially part of the body but it needs to include // The increment is essentially part of the body but it needs to include
...@@ -444,15 +447,14 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -444,15 +447,14 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
Visit(S->getRangeStmt()); Visit(S->getRangeStmt());
Visit(S->getBeginEndStmt()); Visit(S->getBeginEndStmt());
uint64_t ParentCount = PGO.getCurrentRegionCount(); uint64_t ParentCount = CurrentCount;
BreakContinueStack.push_back(BreakContinue()); BreakContinueStack.push_back(BreakContinue());
// Visit the body region first. (This is basically the same as a while // Visit the body region first. (This is basically the same as a while
// loop; see further comments in VisitWhileStmt.) // loop; see further comments in VisitWhileStmt.)
uint64_t BodyCount = setCount(PGO.getRegionCount(S)); uint64_t BodyCount = setCount(PGO.getRegionCount(S));
CountMap[S->getBody()] = BodyCount; CountMap[S->getBody()] = BodyCount;
Visit(S->getBody()); Visit(S->getBody());
uint64_t BackedgeCount = PGO.getCurrentRegionCount(); uint64_t BackedgeCount = CurrentCount;
BreakContinue BC = BreakContinueStack.pop_back_val(); BreakContinue BC = BreakContinueStack.pop_back_val();
// The increment is essentially part of the body but it needs to include // The increment is essentially part of the body but it needs to include
...@@ -473,13 +475,13 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -473,13 +475,13 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
RecordStmtCount(S); RecordStmtCount(S);
Visit(S->getElement()); Visit(S->getElement());
uint64_t ParentCount = PGO.getCurrentRegionCount(); uint64_t ParentCount = CurrentCount;
BreakContinueStack.push_back(BreakContinue()); BreakContinueStack.push_back(BreakContinue());
// Counter tracks the body of the loop. // Counter tracks the body of the loop.
uint64_t BodyCount = setCount(PGO.getRegionCount(S)); uint64_t BodyCount = setCount(PGO.getRegionCount(S));
CountMap[S->getBody()] = BodyCount; CountMap[S->getBody()] = BodyCount;
Visit(S->getBody()); Visit(S->getBody());
uint64_t BackedgeCount = PGO.getCurrentRegionCount(); uint64_t BackedgeCount = CurrentCount;
BreakContinue BC = BreakContinueStack.pop_back_val(); BreakContinue BC = BreakContinueStack.pop_back_val();
setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
...@@ -490,7 +492,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -490,7 +492,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
void VisitSwitchStmt(const SwitchStmt *S) { void VisitSwitchStmt(const SwitchStmt *S) {
RecordStmtCount(S); RecordStmtCount(S);
Visit(S->getCond()); Visit(S->getCond());
PGO.setCurrentRegionUnreachable(); CurrentCount = 0;
BreakContinueStack.push_back(BreakContinue()); BreakContinueStack.push_back(BreakContinue());
Visit(S->getBody()); Visit(S->getBody());
// If the switch is inside a loop, add the continue counts. // If the switch is inside a loop, add the continue counts.
...@@ -508,7 +510,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -508,7 +510,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
// switch header and does not include fallthrough from the case before // switch header and does not include fallthrough from the case before
// this one. // this one.
uint64_t CaseCount = PGO.getRegionCount(S); uint64_t CaseCount = PGO.getRegionCount(S);
setCount(PGO.getCurrentRegionCount() + CaseCount); setCount(CurrentCount + CaseCount);
// We need the count without fallthrough in the mapping, so it's more useful // We need the count without fallthrough in the mapping, so it's more useful
// for branch probabilities. // for branch probabilities.
CountMap[S] = CaseCount; CountMap[S] = CaseCount;
...@@ -518,7 +520,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -518,7 +520,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
void VisitIfStmt(const IfStmt *S) { void VisitIfStmt(const IfStmt *S) {
RecordStmtCount(S); RecordStmtCount(S);
uint64_t ParentCount = PGO.getCurrentRegionCount(); uint64_t ParentCount = CurrentCount;
Visit(S->getCond()); Visit(S->getCond());
// Counter tracks the "then" part of an if statement. The count for // Counter tracks the "then" part of an if statement. The count for
...@@ -526,14 +528,14 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -526,14 +528,14 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
uint64_t ThenCount = setCount(PGO.getRegionCount(S)); uint64_t ThenCount = setCount(PGO.getRegionCount(S));
CountMap[S->getThen()] = ThenCount; CountMap[S->getThen()] = ThenCount;
Visit(S->getThen()); Visit(S->getThen());
uint64_t OutCount = PGO.getCurrentRegionCount(); uint64_t OutCount = CurrentCount;
uint64_t ElseCount = ParentCount - ThenCount; uint64_t ElseCount = ParentCount - ThenCount;
if (S->getElse()) { if (S->getElse()) {
setCount(ElseCount); setCount(ElseCount);
CountMap[S->getElse()] = ElseCount; CountMap[S->getElse()] = ElseCount;
Visit(S->getElse()); Visit(S->getElse());
OutCount += PGO.getCurrentRegionCount(); OutCount += CurrentCount;
} else } else
OutCount += ElseCount; OutCount += ElseCount;
setCount(OutCount); setCount(OutCount);
...@@ -560,7 +562,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -560,7 +562,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
RecordStmtCount(E); RecordStmtCount(E);
uint64_t ParentCount = PGO.getCurrentRegionCount(); uint64_t ParentCount = CurrentCount;
Visit(E->getCond()); Visit(E->getCond());
// Counter tracks the "true" part of a conditional operator. The // Counter tracks the "true" part of a conditional operator. The
...@@ -568,12 +570,12 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -568,12 +570,12 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
uint64_t TrueCount = setCount(PGO.getRegionCount(E)); uint64_t TrueCount = setCount(PGO.getRegionCount(E));
CountMap[E->getTrueExpr()] = TrueCount; CountMap[E->getTrueExpr()] = TrueCount;
Visit(E->getTrueExpr()); Visit(E->getTrueExpr());
uint64_t OutCount = PGO.getCurrentRegionCount(); uint64_t OutCount = CurrentCount;
uint64_t FalseCount = setCount(ParentCount - TrueCount); uint64_t FalseCount = setCount(ParentCount - TrueCount);
CountMap[E->getFalseExpr()] = FalseCount; CountMap[E->getFalseExpr()] = FalseCount;
Visit(E->getFalseExpr()); Visit(E->getFalseExpr());
OutCount += PGO.getCurrentRegionCount(); OutCount += CurrentCount;
setCount(OutCount); setCount(OutCount);
RecordNextStmtCount = true; RecordNextStmtCount = true;
...@@ -581,25 +583,25 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { ...@@ -581,25 +583,25 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
void VisitBinLAnd(const BinaryOperator *E) { void VisitBinLAnd(const BinaryOperator *E) {
RecordStmtCount(E); RecordStmtCount(E);
uint64_t ParentCount = PGO.getCurrentRegionCount(); uint64_t ParentCount = CurrentCount;
Visit(E->getLHS()); Visit(E->getLHS());
// Counter tracks the right hand side of a logical and operator. // Counter tracks the right hand side of a logical and operator.
uint64_t RHSCount = setCount(PGO.getRegionCount(E)); uint64_t RHSCount = setCount(PGO.getRegionCount(E));
CountMap[E->getRHS()] = RHSCount; CountMap[E->getRHS()] = RHSCount;
Visit(E->getRHS()); Visit(E->getRHS());
setCount(ParentCount + RHSCount - PGO.getCurrentRegionCount()); setCount(ParentCount + RHSCount - CurrentCount);
RecordNextStmtCount = true; RecordNextStmtCount = true;
} }
void VisitBinLOr(const BinaryOperator *E) { void VisitBinLOr(const BinaryOperator *E) {
RecordStmtCount(E); RecordStmtCount(E);
uint64_t ParentCount = PGO.getCurrentRegionCount(); uint64_t ParentCount = CurrentCount;
Visit(E->getLHS()); Visit(E->getLHS());
// Counter tracks the right hand side of a logical or operator. // Counter tracks the right hand side of a logical or operator.
uint64_t RHSCount = setCount(PGO.getRegionCount(E)); uint64_t RHSCount = setCount(PGO.getRegionCount(E));
CountMap[E->getRHS()] = RHSCount; CountMap[E->getRHS()] = RHSCount;
Visit(E->getRHS()); Visit(E->getRHS());
setCount(ParentCount + RHSCount - PGO.getCurrentRegionCount()); setCount(ParentCount + RHSCount - CurrentCount);
RecordNextStmtCount = true; RecordNextStmtCount = true;
} }
}; };
......
...@@ -60,11 +60,6 @@ public: ...@@ -60,11 +60,6 @@ public:
/// exits. /// exits.
void setCurrentRegionCount(uint64_t Count) { CurrentRegionCount = Count; } void setCurrentRegionCount(uint64_t Count) { CurrentRegionCount = Count; }
/// Indicate that the current region is never reached, and thus should have a
/// counter value of zero. This is important so that subsequent regions can
/// correctly track their parent counts.
void setCurrentRegionUnreachable() { setCurrentRegionCount(0); }
/// Check if an execution count is known for a given statement. If so, return /// Check if an execution count is known for a given statement. If so, return
/// true and put the value in Count; else return false. /// true and put the value in Count; else return false.
Optional<uint64_t> getStmtCount(const Stmt *S) { Optional<uint64_t> getStmtCount(const Stmt *S) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment