From dd5c1b62ec59630ff9e2adde00421c3f9c619933 Mon Sep 17 00:00:00 2001 From: Maximiliano Puccio Date: Tue, 17 Oct 2023 17:50:04 +0200 Subject: [PATCH] Flatten cell neighbours structure --- .../tracking/include/ITStracking/TimeFrame.h | 12 +++--- .../ITSMFT/ITS/tracking/src/TimeFrame.cxx | 6 +-- .../ITSMFT/ITS/tracking/src/TrackerTraits.cxx | 41 ++++++++++++------- 3 files changed, 36 insertions(+), 23 deletions(-) diff --git a/Detectors/ITSMFT/ITS/tracking/include/ITStracking/TimeFrame.h b/Detectors/ITSMFT/ITS/tracking/include/ITStracking/TimeFrame.h index da35b1714978a..82960c2d1315e 100644 --- a/Detectors/ITSMFT/ITS/tracking/include/ITStracking/TimeFrame.h +++ b/Detectors/ITSMFT/ITS/tracking/include/ITStracking/TimeFrame.h @@ -155,7 +155,8 @@ class TimeFrame std::vector>& getCellSeedsChi2() { return mCellSeedsChi2; } std::vector>& getCellsLookupTable(); - std::vector>>& getCellsNeighbours(); + std::vector>& getCellsNeighbours(); + std::vector>& getCellsNeighboursLUT(); std::vector>& getRoads(); std::vector& getTracks(int rof) { return mTracks[rof]; } std::vector& getTracksLabel(const int rof) { return mTracksLabel[rof]; } @@ -253,7 +254,8 @@ class TimeFrame std::vector> mCellSeeds; std::vector> mCellSeedsChi2; std::vector> mCellsLookupTable; - std::vector>> mCellsNeighbours; + std::vector> mCellsNeighbours; + std::vector> mCellsNeighboursLUT; std::vector> mRoads; std::vector> mTracksLabel; std::vector> mTracks; @@ -542,10 +544,8 @@ inline std::vector>& TimeFrame::getCellsLookupTable() return mCellsLookupTable; } -inline std::vector>>& TimeFrame::getCellsNeighbours() -{ - return mCellsNeighbours; -} +inline std::vector>& TimeFrame::getCellsNeighbours() { return mCellsNeighbours; } +inline std::vector>& TimeFrame::getCellsNeighboursLUT() { return mCellsNeighboursLUT; } inline std::vector>& TimeFrame::getRoads() { return mRoads; } diff --git a/Detectors/ITSMFT/ITS/tracking/src/TimeFrame.cxx b/Detectors/ITSMFT/ITS/tracking/src/TimeFrame.cxx index 24349e30b8017..247bf0441a691 100644 --- a/Detectors/ITSMFT/ITS/tracking/src/TimeFrame.cxx +++ b/Detectors/ITSMFT/ITS/tracking/src/TimeFrame.cxx @@ -255,6 +255,7 @@ void TimeFrame::initialise(const int iteration, const TrackingParameters& trkPar mCellSeedsChi2.resize(trkParam.CellsPerRoad()); mCellsLookupTable.resize(trkParam.CellsPerRoad() - 1); mCellsNeighbours.resize(trkParam.CellsPerRoad() - 1); + mCellsNeighboursLUT.resize(trkParam.CellsPerRoad() - 1); mCellLabels.resize(trkParam.CellsPerRoad()); mTracklets.resize(std::min(trkParam.TrackletsPerRoad(), maxLayers - 1)); mTrackletLabels.resize(trkParam.TrackletsPerRoad()); @@ -395,6 +396,7 @@ void TimeFrame::initialise(const int iteration, const TrackingParameters& trkPar if (iLayer < (int)mCells.size() - 1) { mCellsLookupTable[iLayer].clear(); mCellsNeighbours[iLayer].clear(); + mCellsNeighboursLUT[iLayer].clear(); } } } @@ -409,9 +411,7 @@ unsigned long TimeFrame::getArtefactsMemory() size += sizeof(Cell) * cells.size(); } for (auto& cellsN : mCellsNeighbours) { - for (auto& vec : cellsN) { - size += sizeof(int) * vec.size(); - } + size += sizeof(int) * cellsN.size(); } return size + sizeof(Road<5>) * mRoads.size(); } diff --git a/Detectors/ITSMFT/ITS/tracking/src/TrackerTraits.cxx b/Detectors/ITSMFT/ITS/tracking/src/TrackerTraits.cxx index ff4642a99f7c3..8149a9a7ce38c 100644 --- a/Detectors/ITSMFT/ITS/tracking/src/TrackerTraits.cxx +++ b/Detectors/ITSMFT/ITS/tracking/src/TrackerTraits.cxx @@ -15,6 +15,7 @@ #include "ITStracking/TrackerTraits.h" +#include #include #include @@ -394,8 +395,13 @@ void TrackerTraits::findCellsNeighbours(const int iteration) int layerCellsNum{static_cast(mTimeFrame->getCells()[iLayer].size())}; const int nextLayerCellsNum{static_cast(mTimeFrame->getCells()[iLayer + 1].size())}; - mTimeFrame->getCellsNeighbours()[iLayer].resize(nextLayerCellsNum); + mTimeFrame->getCellsNeighboursLUT()[iLayer].clear(); + mTimeFrame->getCellsNeighboursLUT()[iLayer].resize(nextLayerCellsNum, 0); + std::vector> cellsNeighbours; + cellsNeighbours.reserve(nextLayerCellsNum); + + std::vector> easyWay(nextLayerCellsNum); for (int iCell{0}; iCell < layerCellsNum; ++iCell) { const Cell& currentCell{mTimeFrame->getCells()[iLayer][iCell]}; @@ -426,8 +432,9 @@ void TrackerTraits::findCellsNeighbours(const int iteration) continue; } - mTimeFrame->getCellsNeighbours()[iLayer][iNextCell].push_back(iCell); - + mTimeFrame->getCellsNeighboursLUT()[iLayer][iNextCell]++; + cellsNeighbours.push_back(std::make_pair(iCell, iNextCell)); + easyWay[iNextCell].push_back(iCell); const int currentCellLevel{currentCell.getLevel()}; if (currentCellLevel >= nextCell.getLevel()) { @@ -435,6 +442,15 @@ void TrackerTraits::findCellsNeighbours(const int iteration) } } } + std::sort(cellsNeighbours.begin(), cellsNeighbours.end(), [](const std::pair& a, const std::pair& b) { + return a.second < b.second; + }); + mTimeFrame->getCellsNeighbours()[iLayer].clear(); + mTimeFrame->getCellsNeighbours()[iLayer].reserve(cellsNeighbours.size()); + for (auto& cellNeighboursIndex : cellsNeighbours) { + mTimeFrame->getCellsNeighbours()[iLayer].push_back(cellNeighboursIndex.first); + } + std::inclusive_scan(mTimeFrame->getCellsNeighboursLUT()[iLayer].begin(), mTimeFrame->getCellsNeighboursLUT()[iLayer].end(), mTimeFrame->getCellsNeighboursLUT()[iLayer].begin()); } } @@ -457,11 +473,11 @@ void TrackerTraits::findRoads(const int iteration) if (iLevel == 1) { continue; } - const int cellNeighboursNum{static_cast( - mTimeFrame->getCellsNeighbours()[iLayer - 1][iCell].size())}; + const int startNeighbourId{iCell ? mTimeFrame->getCellsNeighboursLUT()[iLayer - 1][iCell - 1] : 0}; + const int endNeighbourId{mTimeFrame->getCellsNeighboursLUT()[iLayer - 1][iCell]}; bool isFirstValidNeighbour = true; - for (int iNeighbourCell{0}; iNeighbourCell < cellNeighboursNum; ++iNeighbourCell) { - const int neighbourCellId = mTimeFrame->getCellsNeighbours()[iLayer - 1][iCell][iNeighbourCell]; + for (int iNeighbourCell{startNeighbourId}; iNeighbourCell < endNeighbourId; ++iNeighbourCell) { + const int neighbourCellId = mTimeFrame->getCellsNeighbours()[iLayer - 1][iNeighbourCell]; const Cell& neighbourCell = mTimeFrame->getCells()[iLayer - 1][neighbourCellId]; if (iLevel - 1 != neighbourCell.getLevel()) { continue; @@ -973,14 +989,11 @@ void TrackerTraits::traverseCellsTree(const int currentCellId, const int current mTimeFrame->getRoads().back().addCell(currentLayerId, currentCellId); if (currentLayerId > 0 && currentCellLevel > 1) { - const int cellNeighboursNum{static_cast( - mTimeFrame->getCellsNeighbours()[currentLayerId - 1][currentCellId].size())}; bool isFirstValidNeighbour = true; - - for (int iNeighbourCell{0}; iNeighbourCell < cellNeighboursNum; ++iNeighbourCell) { - - const int neighbourCellId = - mTimeFrame->getCellsNeighbours()[currentLayerId - 1][currentCellId][iNeighbourCell]; + const int startNeighbourId = currentCellId ? mTimeFrame->getCellsNeighboursLUT()[currentLayerId - 1][currentCellId - 1] : 0; + const int endNeighbourId = mTimeFrame->getCellsNeighboursLUT()[currentLayerId - 1][currentCellId]; + for (int iNeighbourCell{startNeighbourId}; iNeighbourCell < endNeighbourId; ++iNeighbourCell) { + const int neighbourCellId = mTimeFrame->getCellsNeighbours()[currentLayerId - 1][iNeighbourCell]; const Cell& neighbourCell = mTimeFrame->getCells()[currentLayerId - 1][neighbourCellId]; if (currentCellLevel - 1 != neighbourCell.getLevel()) {