diff --git a/Grinder/Grinder.pro b/Grinder/Grinder.pro index a9023ff04cdf16a94496dd44b59c552ee0238b24..d68e9859e8b1ac8d36b0968a4c52181043c8dfed 100644 --- a/Grinder/Grinder.pro +++ b/Grinder/Grinder.pro @@ -474,7 +474,9 @@ SOURCES += \ ml/external/msgs/InferImageMessage.cpp \ ml/properties/TilingTypeProperty.cpp \ cv/MatrixTiler.cpp \ - ml/tasks/InferenceTileQueue.cpp + ml/tasks/InferenceTileQueue.cpp \ + ml/tasks/MachineLearningTrainingTask.cpp \ + ml/tasks/MachineLearningInferenceTask.cpp HEADERS += \ ui/mainwnd/GrinderWindow.h \ @@ -1033,7 +1035,9 @@ HEADERS += \ ml/properties/TilingTypeProperty.h \ ml/MLTypes.h \ cv/MatrixTiler.h \ - ml/tasks/InferenceTileQueue.h + ml/tasks/InferenceTileQueue.h \ + ml/tasks/MachineLearningTrainingTask.h \ + ml/tasks/MachineLearningInferenceTask.h FORMS += \ ui/mainwnd/GrinderWindow.ui \ diff --git a/Grinder/Version.h b/Grinder/Version.h index fdb2442411953934622d2827320e7737f748508d..9bc3062c581b981e052313e4fa9db36f7469fd81 100644 --- a/Grinder/Version.h +++ b/Grinder/Version.h @@ -10,14 +10,14 @@ #define GRNDR_INFO_TITLE "Grinder" #define GRNDR_INFO_COPYRIGHT "Copyright (c) WWU Muenster" -#define GRNDR_INFO_DATE "07.10.2019" +#define GRNDR_INFO_DATE "08.10.2019" #define GRNDR_INFO_COMPANY "WWU Muenster" #define GRNDR_INFO_WEBSITE "http://www.uni-muenster.de" #define GRNDR_VERSION_MAJOR 0 #define GRNDR_VERSION_MINOR 15 #define GRNDR_VERSION_REVISION 0 -#define GRNDR_VERSION_BUILD 399 +#define GRNDR_VERSION_BUILD 400 namespace grndr { diff --git a/Grinder/ml/barista/tasks/BaristaInferenceTask.cpp b/Grinder/ml/barista/tasks/BaristaInferenceTask.cpp index 0f42adc6936bf18d031fe8bd799669dccb56a3ff..ec95c4981cf9bcf6c1ff86df94d9c71cdf0b1ce4 100644 --- a/Grinder/ml/barista/tasks/BaristaInferenceTask.cpp +++ b/Grinder/ml/barista/tasks/BaristaInferenceTask.cpp @@ -29,35 +29,6 @@ void BaristaInferenceTask::registerMessageHandlers() registerMessageHandler<BaristaInferImageMessage>(BaristaInferImageMessage::Key, &BaristaInferenceTask::handleInferImageMessage); } -void BaristaInferenceTask::inferImage(const cv::Mat& imageData, const ImageReference* imageRef) -{ - addLogMessage(QString{"\tPerforming inference on '%1'..."}.arg(imageRef->getImageFileName())); - - if (imageData.empty()) - throw BaristaException{_EXCPT(QString{"Failed to get data for '%1'"}.arg(imageRef->getImageFilePath()))}; - - // Fill the image tile queue - _tileQueue.beginQueue(_imageTiler.tile(imageData)); - - // Send the first image tile for inference - sendInferImageMessage(); -} - -void BaristaInferenceTask::processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) -{ - MachineLearningTask::processEngineEnd(ctx, data); - - if (!ctx.wasAborted()) - { - setProgress(1.0f); - addLogMessage("All images have been processed"); - } - - // The inference has finished, so break the Barista connection and finish the task - shutdownTask(); - finishTask(true); -} - void BaristaInferenceTask::createBaristaNetworkContext() { _networkContext = std::make_unique<BaristaNetworkContext>(BaristaNetworkContext::NetworkType::Inference, _outputDirectory, _remoteDirectory); @@ -85,6 +56,13 @@ void BaristaInferenceTask::verifyTask() const throw TaskException{this, _EXCPT("No state selected")}; } +void BaristaInferenceTask::finishInference() +{ + // The inference has finished, so break the Barista connection and finish the task + shutdownTask(); + finishTask(true); +} + void BaristaInferenceTask::sendLoadNetworkMessage() { auto network = _networkContext->resolveRemoteFile(_network->networkInfo().getInferenceNetwork()); @@ -198,20 +176,6 @@ void BaristaInferenceTask::processInferenceResult(const BaristaInferImageMessage throw BaristaException{_EXCPT("Invalid dimensions")}; } -void BaristaInferenceTask::finishInferenceResult() -{ - ImageInferenceResults results; - - for (unsigned int tagIndex = 0; tagIndex < _inputImageTags->tags().size(); ++tagIndex) - { - auto imageTag = _inputImageTags->tags().at(tagIndex).get(); - results[imageTag] = _imageTiler.combine(_tileQueue.getResultTiles(imageTag)); - } - - // Inform about the arrived results via an event - emit inferImageFinished(results); -} - cv::Mat BaristaInferenceTask::extractProbabilityData(cv::Mat& probData, QSize dataSize, unsigned int index) const { // Extract the probability data diff --git a/Grinder/ml/barista/tasks/BaristaInferenceTask.h b/Grinder/ml/barista/tasks/BaristaInferenceTask.h index d85a37123cdbe34e28c4f98ce37a4be05a4158a4..500780beb638c66f5f61c75634830351e24f17c2 100644 --- a/Grinder/ml/barista/tasks/BaristaInferenceTask.h +++ b/Grinder/ml/barista/tasks/BaristaInferenceTask.h @@ -9,7 +9,7 @@ #include <opencv2/core.hpp> #include "BaristaTask.h" -#include "ml/tasks/InferenceTileQueue.h" +#include "ml/tasks/MachineLearningInferenceTask.h" #include "ml/barista/msgs/BaristaLoadNetworkMessage.h" namespace grndr @@ -17,7 +17,7 @@ namespace grndr class BaristaInferImageMessage; class ImageTag; - class BaristaInferenceTask : public BaristaTask<BaristaInferenceTask> + class BaristaInferenceTask : public BaristaTask<MachineLearningInferenceTask, BaristaInferenceTask> { Q_OBJECT @@ -30,12 +30,6 @@ namespace grndr public: virtual void registerMessageHandlers() override; - public: - virtual void inferImage(const cv::Mat& imageData, const ImageReference* imageRef) override; - - public: - virtual void processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; - protected: virtual void createBaristaNetworkContext() override; virtual void prepareBaristaNetworkContext() override; @@ -45,6 +39,10 @@ namespace grndr protected: virtual void verifyTask() const override; + protected: + virtual void performInference() override { sendInferImageMessage(); } + virtual void finishInference() override; + private: void sendLoadNetworkMessage(); void sendInferImageMessage(); @@ -55,7 +53,6 @@ namespace grndr private: void processInferenceResult(const BaristaInferImageMessage* result, QString outputName); - void finishInferenceResult(); cv::Mat extractProbabilityData(cv::Mat& probData, QSize dataSize, unsigned int index) const; @@ -69,10 +66,6 @@ namespace grndr private: std::vector<BaristaLoadNetworkMessage::DataInformation> _networkInputs; std::vector<BaristaLoadNetworkMessage::DataInformation> _networkOutputs; - - private: - InferenceTileQueue _tileQueue; - unsigned int _currentTileIndex{0}; }; } diff --git a/Grinder/ml/barista/tasks/BaristaTask.h b/Grinder/ml/barista/tasks/BaristaTask.h index caf8d1dd55f9afbb987690ca214c59733e2cb2d0..37e76b239099edfab3501f73a5f3541bdcbcf1a6 100644 --- a/Grinder/ml/barista/tasks/BaristaTask.h +++ b/Grinder/ml/barista/tasks/BaristaTask.h @@ -14,9 +14,11 @@ namespace grndr { - template<typename ClassType> - class BaristaTask : public MachineLearningTask, public NetworkMessageHandler<ClassType, BaristaMessage> + template<typename BaseType, typename ClassType> + class BaristaTask : public BaseType, public NetworkMessageHandler<ClassType, BaristaMessage> { + static_assert(std::is_base_of<MachineLearningTask, BaseType>::value, "BaseType must be derived from MachineLearningTask"); + public: static const char* Serialization_Value_BaristaPort; static const char* Serialization_Value_LibraryPath; @@ -24,6 +26,7 @@ namespace grndr static const char* Serialization_Value_RemoteDirectory; public: + using base_type = BaseType; using class_type = ClassType; using handler_type = NetworkMessageHandler<ClassType, BaristaMessage>; diff --git a/Grinder/ml/barista/tasks/BaristaTask.impl.h b/Grinder/ml/barista/tasks/BaristaTask.impl.h index 0e8529faa27f79e94c3615243e4f1ec3dcb91003..618ad26e44e210542176a65acfdb018dac558162 100644 --- a/Grinder/ml/barista/tasks/BaristaTask.impl.h +++ b/Grinder/ml/barista/tasks/BaristaTask.impl.h @@ -11,45 +11,45 @@ #include "ml/barista/msgs/BaristaShutdownMessage.h" -template<typename ClassType> -const char* BaristaTask<ClassType>::Serialization_Value_BaristaPort = "BaristaPort"; -template<typename ClassType> -const char* BaristaTask<ClassType>::Serialization_Value_LibraryPath = "LibraryPath"; -template<typename ClassType> -const char* BaristaTask<ClassType>::Serialization_Value_Network = "Network"; -template<typename ClassType> -const char* BaristaTask<ClassType>::Serialization_Value_RemoteDirectory = "RemoteDirectory"; - -template<typename ClassType> -BaristaTask<ClassType>::BaristaTask(class_type* handlerTarget, TaskPool* taskPool, TaskType type, QString name) : MachineLearningTask(taskPool, type, Task::Capability::CanBeStopped|Task::Capability::HasProgress, name), NetworkMessageHandler<ClassType, BaristaMessage>(&_baristaInterface, handlerTarget) +template<typename BaseType, typename ClassType> +const char* BaristaTask<BaseType, ClassType>::Serialization_Value_BaristaPort = "BaristaPort"; +template<typename BaseType, typename ClassType> +const char* BaristaTask<BaseType, ClassType>::Serialization_Value_LibraryPath = "LibraryPath"; +template<typename BaseType, typename ClassType> +const char* BaristaTask<BaseType, ClassType>::Serialization_Value_Network = "Network"; +template<typename BaseType, typename ClassType> +const char* BaristaTask<BaseType, ClassType>::Serialization_Value_RemoteDirectory = "RemoteDirectory"; + +template<typename BaseType, typename ClassType> +BaristaTask<BaseType, ClassType>::BaristaTask(class_type* handlerTarget, TaskPool* taskPool, TaskType type, QString name) : BaseType(taskPool, type, Task::Capability::CanBeStopped|Task::Capability::HasProgress, name), NetworkMessageHandler<ClassType, BaristaMessage>(&_baristaInterface, handlerTarget) { // Register the task as a message handler _baristaInterface.registerMessageHandler(this); // When the task has been stopped, immediately finish it - connect(this, &MachineLearningTask::taskStopped, [this]() { finishTask(false); }); + this->connect(this, &base_type::taskStopped, [this]() { this->finishTask(false); }); // The Barista interface will notify us when it is ready or an error occurred - connect(&_baristaInterface, &BaristaInterface::clientReady, this, &BaristaTask<ClassType>::baristaClientReady); - connect(&_baristaInterface, &BaristaInterface::clientFailure, this, &BaristaTask<ClassType>::baristaClientFailure); + this->connect(&_baristaInterface, &BaristaInterface::clientReady, this, &BaristaTask<BaseType, ClassType>::baristaClientReady); + this->connect(&_baristaInterface, &BaristaInterface::clientFailure, this, &BaristaTask<BaseType, ClassType>::baristaClientFailure); } -template<typename ClassType> -void BaristaTask<ClassType>::registerMessageHandlers() +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::registerMessageHandlers() { } -template<typename ClassType> -void BaristaTask<ClassType>::setNetwork(BaristaNetwork* network) +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::setNetwork(BaristaNetwork* network) { _network = network; } -template<typename ClassType> -void BaristaTask<ClassType>::serialize(SerializationContext& ctx) const +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::serialize(SerializationContext& ctx) const { - MachineLearningTask::serialize(ctx); + base_type::serialize(ctx); // Serialize values ctx.settings()(Serialization_Value_BaristaPort) = _baristaPort; @@ -58,10 +58,10 @@ void BaristaTask<ClassType>::serialize(SerializationContext& ctx) const ctx.settings()(Serialization_Value_RemoteDirectory) = _remoteDirectory; } -template<typename ClassType> -void BaristaTask<ClassType>::deserialize(DeserializationContext& ctx) +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::deserialize(DeserializationContext& ctx) { - MachineLearningTask::deserialize(ctx); + base_type::deserialize(ctx); // Deserialize values QString networkName = ctx.settings()(Serialization_Value_Network).toString(); @@ -72,10 +72,10 @@ void BaristaTask<ClassType>::deserialize(DeserializationContext& ctx) _remoteDirectory = ctx.settings()(Serialization_Value_RemoteDirectory).toString(); } -template<typename ClassType> -void BaristaTask<ClassType>::verifyTask() const +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::verifyTask() const { - MachineLearningTask::verifyTask(); + base_type::verifyTask(); if (_libraryPath.isEmpty()) throw TaskException{this, _EXCPT("No library path provided")}; @@ -84,44 +84,44 @@ void BaristaTask<ClassType>::verifyTask() const throw TaskException{this, _EXCPT("No Barista network provided")}; } -template<typename ClassType> -void BaristaTask<ClassType>::execute() +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::execute() { initializeTask(); - MachineLearningTask::execute(); + base_type::execute(); } -template<typename ClassType> -void BaristaTask<ClassType>::stop() +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::stop() { shutdownTask(); - MachineLearningTask::stop(); + base_type::stop(); } -template<typename ClassType> -void BaristaTask<ClassType>::update() +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::update() { - if (_taskState > TaskState::Initializing) + if (this->_taskState > MachineLearningTask::TaskState::Initializing) _baristaInterface.process(); } -template<typename ClassType> -void BaristaTask<ClassType>::finish(bool succeeded) +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::finish(bool succeeded) { Q_UNUSED(succeeded); shutdownTask(); } -template<typename ClassType> -void BaristaTask<ClassType>::initializeTask() +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::initializeTask() { - changeTaskState(TaskState::Initializing); + this->changeTaskState(MachineLearningTask::TaskState::Initializing); // Create and compile the Barista network so that all files are ready and in place - addLogMessage("Preparing Barista network..."); + this->addLogMessage("Preparing Barista network..."); if (!_network) throw TaskException{this, _EXCPT("No Barista network selected")}; @@ -142,13 +142,13 @@ void BaristaTask<ClassType>::initializeTask() _baristaInterface.waitForConnection(_baristaPort); // Enter the awaiting connection state - changeTaskState(TaskState::AwaitingConnection, "Waiting for Barista to connect..."); + this->changeTaskState(MachineLearningTask::TaskState::AwaitingConnection, "Waiting for Barista to connect..."); } -template<typename ClassType> -void BaristaTask<ClassType>::shutdownTask(bool sendShutdownMsg) +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::shutdownTask(bool sendShutdownMsg) { - changeTaskState(TaskState::Shutdown, "Shutting down Barista connection..."); + this->changeTaskState(MachineLearningTask::TaskState::Shutdown, "Shutting down Barista connection..."); // Clean up the Barista network if (_network && _networkContext) @@ -165,8 +165,8 @@ void BaristaTask<ClassType>::shutdownTask(bool sendShutdownMsg) _baristaInterface.stop(); } -template<typename ClassType> -void BaristaTask<ClassType>::reportError(QString error, const BaristaMessage* message) +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::reportError(QString error, const BaristaMessage* message) { if (message) { @@ -176,38 +176,38 @@ void BaristaTask<ClassType>::reportError(QString error, const BaristaMessage* me error += QString{" (%1)"}.arg(baristaError); } - addLogMessage(error); + this->addLogMessage(error); shutdownTask(); - finishTask(false); + this->finishTask(false); } -template<typename ClassType> -void BaristaTask<ClassType>::reportUnexpectedMessage(const BaristaMessage* message) +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::reportUnexpectedMessage(const BaristaMessage* message) { - reportError(QString{"Unexpected message received (%1) during state %2"}.arg(message->getKey()).arg(_taskState)); + this->reportError(QString{"Unexpected message received (%1) during state %2"}.arg(message->getKey()).arg(this->_taskState)); } -template<typename ClassType> -void BaristaTask<ClassType>::prepareBaristaNetworkContext() +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::prepareBaristaNetworkContext() { // Output count (based on available image tags) - unsigned int imageTagsCount = _inputImageTags ? _inputImageTags->tags().size() : 0; + unsigned int imageTagsCount = this->_inputImageTags ? this->_inputImageTags->tags().size() : 0; _networkContext->addVariable(BaristaNetworkContext::Variable_OutputCount, imageTagsCount); _networkContext->addVariable(BaristaNetworkContext::Variable_OutputCountPlus1, imageTagsCount + 1); } -template<typename ClassType> -void BaristaTask<ClassType>::sendShutdownMessage() +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::sendShutdownMessage() { auto shutdown = this->template createMessage<BaristaShutdownMessage>(); _baristaInterface.sendMessage(*shutdown); } -template<typename ClassType> -void BaristaTask<ClassType>::baristaClientReady(QString ip) +template<typename BaseType, typename ClassType> +void BaristaTask<BaseType, ClassType>::baristaClientReady(QString ip) { - changeTaskState(TaskState::Connected, QString{"\tConnected to Barista at '%1'"}.arg(ip)); + this->changeTaskState(MachineLearningTask::TaskState::Connected, QString{"\tConnected to Barista at '%1'"}.arg(ip)); // Let the task-specific implementation handle the rest from here on baristaReady(); diff --git a/Grinder/ml/barista/tasks/BaristaTrainingTask.cpp b/Grinder/ml/barista/tasks/BaristaTrainingTask.cpp index 85a5799b7021a1fa82915dfeee27e64c5b28104e..4aaaa6a7b8149f32e74a91e2e348e1f8de53c116 100644 --- a/Grinder/ml/barista/tasks/BaristaTrainingTask.cpp +++ b/Grinder/ml/barista/tasks/BaristaTrainingTask.cpp @@ -34,55 +34,6 @@ void BaristaTrainingTask::registerMessageHandlers() registerMessageHandler<BaristaFinishTrainingMessage>(BaristaFinishTrainingMessage::Key, &BaristaTrainingTask::handleFinishTrainingMessage); } -void BaristaTrainingTask::processEngineStart(EngineExecutionContext& ctx, const MachineLearningTaskData& data) -{ - BaristaTask::processEngineStart(ctx, data); - - // Create the HDF5 file in the output directory - QFileInfo fi{_outputDirectory, FILE_TRAINING_DATA_HDF5}; - _h5Export = std::make_unique<HDF5Export>(fi.filePath()); - - HDF5Export::ExportFlags exportFlags = HDF5Export::ExportFlag::ExportTags|HDF5Export::ExportFlag::MergeTags; - - if (_network->networkInfo().requiresGrayscale()) - exportFlags |= HDF5Export::ExportFlag::ExportAsGrayscale; - - _h5Export->initExport(data.imageTiler.getTileSize(), ctx.imageReferences().size() * data.imageTiler.tiles().size(), 1, exportFlags); -} - -void BaristaTrainingTask::processEnginePass(EngineExecutionContext& ctx, const MachineLearningTaskData& data) -{ - BaristaTask::processEnginePass(ctx, data); - - if (!_h5Export) - { - reportError("The HDF5 file isn't ready for exporting"); - return; - } - - std::vector<cv::Mat> imageDataTiles = data.imageTiler.tile(data.imageData); - std::vector<cv::Mat> imageTagsDataTiles; - - if (!data.imageTagsData.empty()) - imageTagsDataTiles = data.imageTiler.tile(data.imageTagsData); - - for (unsigned int i = 0; i < imageDataTiles.size(); ++i) - { - if (!imageTagsDataTiles.empty()) - _h5Export->exportImageEx(imageDataTiles[i], {imageTagsDataTiles[i]}); - else - _h5Export->exportImage(imageDataTiles[i]); - } -} - -void BaristaTrainingTask::processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) -{ - BaristaTask::processEngineEnd(ctx, data); - - // We no longer need the HDF5 export - _h5Export = nullptr; -} - void BaristaTrainingTask::serialize(SerializationContext& ctx) const { BaristaTask::serialize(ctx); @@ -129,6 +80,21 @@ void BaristaTrainingTask::baristaReady() sendStartTrainingMessage(); } +QString BaristaTrainingTask::getTrainingDataFile() const +{ + return QFileInfo{_outputDirectory, FILE_TRAINING_DATA_HDF5}.filePath(); +} + +HDF5Export::ExportFlags BaristaTrainingTask::getTrainingDataFileFlags() const +{ + HDF5Export::ExportFlags exportFlags = HDF5Export::ExportFlag::ExportTags|HDF5Export::ExportFlag::MergeTags; + + if (_network->networkInfo().requiresGrayscale()) + exportFlags |= HDF5Export::ExportFlag::ExportAsGrayscale; + + return exportFlags; +} + void BaristaTrainingTask::sendStartTrainingMessage() { auto sessionDir = _networkContext->getRemoteDirectory(); diff --git a/Grinder/ml/barista/tasks/BaristaTrainingTask.h b/Grinder/ml/barista/tasks/BaristaTrainingTask.h index 1de01090a0a3f9b4b6fa2d14a5b79566eed570a9..38563dcb3187b44193340f4863803088a5b1861c 100644 --- a/Grinder/ml/barista/tasks/BaristaTrainingTask.h +++ b/Grinder/ml/barista/tasks/BaristaTrainingTask.h @@ -7,11 +7,11 @@ #define BARISTATRAININGTASK_H #include "BaristaTask.h" -#include "project/exporters/HDF5Export.h" +#include "ml/tasks/MachineLearningTrainingTask.h" namespace grndr { - class BaristaTrainingTask : public BaristaTask<BaristaTrainingTask> + class BaristaTrainingTask : public BaristaTask<MachineLearningTrainingTask, BaristaTrainingTask> { Q_OBJECT @@ -28,11 +28,6 @@ namespace grndr public: virtual void registerMessageHandlers() override; - public: - virtual void processEngineStart(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; - virtual void processEnginePass(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; - virtual void processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; - public: unsigned int getMaxIterations() const { return _maxIterations; } void setMaxIterations(unsigned int maxIter) { _maxIterations = maxIter; } @@ -51,6 +46,13 @@ namespace grndr virtual void baristaReady() override; + protected: + virtual void taskError(QString error) override { reportError(error); } + + protected: + virtual QString getTrainingDataFile() const override; + virtual HDF5Export::ExportFlags getTrainingDataFileFlags() const override; + private: void sendStartTrainingMessage(); @@ -70,9 +72,6 @@ namespace grndr unsigned int _maxIterations{5000}; unsigned int _displayInterval{100}; unsigned int _snapshotInterval{1000}; - - private: - std::unique_ptr<HDF5Export> _h5Export; }; } diff --git a/Grinder/ml/external/tasks/ExternalClassifierInferenceTask.cpp b/Grinder/ml/external/tasks/ExternalClassifierInferenceTask.cpp index 2ba829205c74eb01782c7631ff2b801811bb43b6..68202f3cb0a3442d890f92383dddd04ee7fe67ae 100644 --- a/Grinder/ml/external/tasks/ExternalClassifierInferenceTask.cpp +++ b/Grinder/ml/external/tasks/ExternalClassifierInferenceTask.cpp @@ -29,33 +29,7 @@ void ExternalClassifierInferenceTask::registerMessageHandlers() void ExternalClassifierInferenceTask::inferImage(const cv::Mat& imageData, const ImageReference* imageRef) { if (isProcessRunning()) - { - addLogMessage(QString{"\tPerforming inference on '%1'..."}.arg(imageRef->getImageFileName())); - - if (imageData.empty()) - throw TaskException{this, _EXCPT(QString{"Failed to get data for '%1'"}.arg(imageRef->getImageFilePath()))}; - - // Fill the image tile queue - _tileQueue.beginQueue(_imageTiler.tile(imageData)); - - // Send the first image tile for inference - sendInferImageMessage(); - } -} - -void ExternalClassifierInferenceTask::processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) -{ - MachineLearningTask::processEngineEnd(ctx, data); - - if (!ctx.wasAborted() && isProcessRunning()) - { - setProgress(1.0f); - addLogMessage("All images have been processed"); - } - - // The inference has finished, so break the Barista connection and finish the task - shutdownTask(); - finishTask(true); + ExternalClassifierTask::inferImage(imageData, imageRef); } void ExternalClassifierInferenceTask::connectionReady() @@ -73,6 +47,13 @@ void ExternalClassifierInferenceTask::verifyTask() const throw TaskException{this, _EXCPT("No state selected")}; } +void ExternalClassifierInferenceTask::finishInference() +{ + // The inference has finished, so finish the task + shutdownTask(); + finishTask(true); +} + void ExternalClassifierInferenceTask::sendStartInferenceMessage() { if (isProcessRunning()) @@ -159,17 +140,3 @@ void ExternalClassifierInferenceTask::processInferenceResult(const AckMessage* a for (unsigned int tagIndex = 0; tagIndex < _inputImageTags->tags().size(); ++tagIndex) _tileQueue.addResultTile(_inputImageTags->tags().at(tagIndex).get(), ackMessage->matrixData(size.width(), size.height(), true, tagIndex, CV_32FC1)); } - -void ExternalClassifierInferenceTask::finishInferenceResult() -{ - ImageInferenceResults results; - - for (unsigned int tagIndex = 0; tagIndex < _inputImageTags->tags().size(); ++tagIndex) - { - auto imageTag = _inputImageTags->tags().at(tagIndex).get(); - results[imageTag] = _imageTiler.combine(_tileQueue.getResultTiles(imageTag)); - } - - // Inform about the arrived results via an event - emit inferImageFinished(results); -} diff --git a/Grinder/ml/external/tasks/ExternalClassifierInferenceTask.h b/Grinder/ml/external/tasks/ExternalClassifierInferenceTask.h index d36ee08a6c7ecdcf47e6bdf498dc9dfa363fdf26..e6a9a65c87124fad28b79a01d16891628d7a6c5c 100644 --- a/Grinder/ml/external/tasks/ExternalClassifierInferenceTask.h +++ b/Grinder/ml/external/tasks/ExternalClassifierInferenceTask.h @@ -7,11 +7,11 @@ #define EXTERNALCLASSIFIERINFERENCETASK_H #include "ExternalClassifierTask.h" -#include "ml/tasks/InferenceTileQueue.h" +#include "ml/tasks/MachineLearningInferenceTask.h" namespace grndr { - class ExternalClassifierInferenceTask : public ExternalClassifierTask<ExternalClassifierInferenceTask> + class ExternalClassifierInferenceTask : public ExternalClassifierTask<MachineLearningInferenceTask, ExternalClassifierInferenceTask> { Q_OBJECT @@ -27,15 +27,18 @@ namespace grndr public: virtual void inferImage(const cv::Mat& imageData, const ImageReference* imageRef) override; - public: - virtual void processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; - protected: virtual void connectionReady() override; protected: virtual void verifyTask() const override; + protected: + virtual void performInference() override { sendInferImageMessage(); } + virtual void finishInference() override; + + virtual bool completeTaskOnEnd() const override { return isProcessRunning(); } + private: void sendStartInferenceMessage(); void sendInferImageMessage(); @@ -48,7 +51,6 @@ namespace grndr private: void processInferenceResult(const AckMessage* ackMessage); - void finishInferenceResult(); private: enum InferenceTaskState @@ -56,10 +58,6 @@ namespace grndr StartInference = TypeSpecificBase, InferImages, }; - - private: - InferenceTileQueue _tileQueue; - unsigned int _currentTileIndex{0}; }; } diff --git a/Grinder/ml/external/tasks/ExternalClassifierTask.h b/Grinder/ml/external/tasks/ExternalClassifierTask.h index 32ece19cbb95bb06dc39be8341a255f42d849082..953cc2d3956221964884ea7a4186a0088701206b 100644 --- a/Grinder/ml/external/tasks/ExternalClassifierTask.h +++ b/Grinder/ml/external/tasks/ExternalClassifierTask.h @@ -13,9 +13,11 @@ namespace grndr { - template<typename ClassType> - class ExternalClassifierTask : public MachineLearningTask, public CommandInterfaceHandler<ClassType> + template<typename BaseType, typename ClassType> + class ExternalClassifierTask : public BaseType, public CommandInterfaceHandler<ClassType> { + static_assert(std::is_base_of<MachineLearningTask, BaseType>::value, "BaseType must be derived from MachineLearningTask"); + public: static const char* Serialization_Value_Command; static const char* Serialization_Value_Arguments; @@ -24,6 +26,7 @@ namespace grndr static const char* Serialization_Value_ControlPort; public: + using base_type = BaseType; using class_type = ClassType; using handler_type = CommandInterfaceHandler<ClassType>; diff --git a/Grinder/ml/external/tasks/ExternalClassifierTask.impl.h b/Grinder/ml/external/tasks/ExternalClassifierTask.impl.h index 86f97d36c2646e0d8861e41f07abeb07a997f878..d9578e6483c08abb9392bcfa5ec9174392d922c2 100644 --- a/Grinder/ml/external/tasks/ExternalClassifierTask.impl.h +++ b/Grinder/ml/external/tasks/ExternalClassifierTask.impl.h @@ -12,30 +12,30 @@ #include "network/cmd/msgs/GrinderInfoMessage.h" #include "network/cmd/msgs/StatusMessage.h" -template<typename ClassType> -const char* ExternalClassifierTask<ClassType>::Serialization_Value_Command = "Command"; -template<typename ClassType> -const char* ExternalClassifierTask<ClassType>::Serialization_Value_Arguments = "Arguments"; -template<typename ClassType> -const char* ExternalClassifierTask<ClassType>::Serialization_Value_UserData = "UserData"; -template<typename ClassType> -const char* ExternalClassifierTask<ClassType>::Serialization_Value_CommandPort = "CommandPort"; -template<typename ClassType> -const char* ExternalClassifierTask<ClassType>::Serialization_Value_ControlPort = "ControlPort"; - -template<typename ClassType> -ExternalClassifierTask<ClassType>::ExternalClassifierTask(class_type* handlerTarget, TaskPool* taskPool, TaskType type, QString name) : MachineLearningTask(taskPool, type, Task::Capability::CanBeStopped|Task::Capability::HasProgress, name), CommandInterfaceHandler<ClassType>(&_cmdInterface, handlerTarget), +template<typename BaseType, typename ClassType> +const char* ExternalClassifierTask<BaseType, ClassType>::Serialization_Value_Command = "Command"; +template<typename BaseType, typename ClassType> +const char* ExternalClassifierTask<BaseType, ClassType>::Serialization_Value_Arguments = "Arguments"; +template<typename BaseType, typename ClassType> +const char* ExternalClassifierTask<BaseType, ClassType>::Serialization_Value_UserData = "UserData"; +template<typename BaseType, typename ClassType> +const char* ExternalClassifierTask<BaseType, ClassType>::Serialization_Value_CommandPort = "CommandPort"; +template<typename BaseType, typename ClassType> +const char* ExternalClassifierTask<BaseType, ClassType>::Serialization_Value_ControlPort = "ControlPort"; + +template<typename BaseType, typename ClassType> +ExternalClassifierTask<BaseType, ClassType>::ExternalClassifierTask(class_type* handlerTarget, TaskPool* taskPool, TaskType type, QString name) : BaseType(taskPool, type, Task::Capability::CanBeStopped|Task::Capability::HasProgress, name), CommandInterfaceHandler<ClassType>(&_cmdInterface, handlerTarget), _cmdInterface{CommandInterface::CoreType::Client} { // Register the task as a message handler _cmdInterface.registerMessageHandler(this); // When the task has been stopped, immediately finish it - connect(this, &MachineLearningTask::taskStopped, [this]() { finishTask(false); }); + this->connect(this, &base_type::taskStopped, [this]() { this->finishTask(false); }); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::registerMessageHandlers() +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::registerMessageHandlers() { handler_type::registerMessageHandlers(); @@ -43,10 +43,10 @@ void ExternalClassifierTask<ClassType>::registerMessageHandlers() this->template registerMessageHandler<StatusMessage>(StatusMessage::Command, &ExternalClassifierTask::handleStatusMessage); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::serialize(SerializationContext& ctx) const +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::serialize(SerializationContext& ctx) const { - MachineLearningTask::serialize(ctx); + base_type::serialize(ctx); // Serialize values ctx.settings()(Serialization_Value_Command) = _executable; @@ -56,10 +56,10 @@ void ExternalClassifierTask<ClassType>::serialize(SerializationContext& ctx) con ctx.settings()(Serialization_Value_ControlPort) = _controlPort; } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::deserialize(DeserializationContext& ctx) +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::deserialize(DeserializationContext& ctx) { - MachineLearningTask::deserialize(ctx); + base_type::deserialize(ctx); // Deserialize values _executable = ctx.settings()(Serialization_Value_Command).toString(); @@ -69,10 +69,10 @@ void ExternalClassifierTask<ClassType>::deserialize(DeserializationContext& ctx) _controlPort = ctx.settings()(Serialization_Value_CommandPort).toUInt(); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::verifyTask() const +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::verifyTask() const { - MachineLearningTask::verifyTask(); + base_type::verifyTask(); if (_executable.isEmpty()) throw TaskException{this, _EXCPT("No command to execute provided")}; @@ -81,8 +81,8 @@ void ExternalClassifierTask<ClassType>::verifyTask() const throw TaskException{this, _EXCPT("The task is already running")}; } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::execute() +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::execute() { // Create the task's external process and start it _process = std::make_unique<QProcess>(); @@ -91,8 +91,8 @@ void ExternalClassifierTask<ClassType>::execute() if (_process->waitForStarted(-1)) { - connect(_process.get(), QOverload<int, QProcess::ExitStatus>::of(&QProcess::finished), this, &ExternalClassifierTask<ClassType>::processFinished); - connect(_process.get(), &QProcess::readyReadStandardOutput, this, &ExternalClassifierTask<ClassType>::outputAvailable); + this->connect(_process.get(), QOverload<int, QProcess::ExitStatus>::of(&QProcess::finished), this, &ExternalClassifierTask<BaseType, ClassType>::processFinished); + this->connect(_process.get(), &QProcess::readyReadStandardOutput, this, &ExternalClassifierTask<BaseType, ClassType>::outputAvailable); } else reportError("Unable to start the external process"); @@ -100,47 +100,47 @@ void ExternalClassifierTask<ClassType>::execute() // Begin the task's work initializeTask(); - MachineLearningTask::execute(); + base_type::execute(); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::update() +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::update() { _cmdInterface.process(); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::stop() +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::stop() { sendControlMessage(ControlMessage::ControlType::Abort); shutdownTask(); - MachineLearningTask::stop(); + base_type::stop(); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::finish(bool succeeded) +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::finish(bool succeeded) { sendControlMessage(succeeded ? ControlMessage::ControlType::Finish : ControlMessage::ControlType::Error); shutdownTask(); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::initializeTask() +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::initializeTask() { - changeTaskState(TaskState::Initializing); + this->changeTaskState(MachineLearningTask::TaskState::Initializing); // Start the network client _cmdInterface.startClient("localhost", _commandPort, _controlPort); - changeTaskState(TaskState::AwaitingConnection, "Waiting for connection..."); + this->changeTaskState(MachineLearningTask::TaskState::AwaitingConnection, "Waiting for connection..."); sendGrinderInfoMessage(); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::shutdownTask(bool sendShutdownMsg) +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::shutdownTask(bool sendShutdownMsg) { - changeTaskState(TaskState::Shutdown, "Shutting down..."); + this->changeTaskState(MachineLearningTask::TaskState::Shutdown, "Shutting down..."); if (sendShutdownMsg) { @@ -165,8 +165,8 @@ void ExternalClassifierTask<ClassType>::shutdownTask(bool sendShutdownMsg) cleanupTask(); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::reportError(QString error, const CommandInterfaceMessage* message) +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::reportError(QString error, const CommandInterfaceMessage* message) { if (message) { @@ -176,38 +176,38 @@ void ExternalClassifierTask<ClassType>::reportError(QString error, const Command error += QString{" (%1)"}.arg(messageError); } - addLogMessage(error); + this->addLogMessage(error); shutdownTask(); - finishTask(false); + this->finishTask(false); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::reportUnexpectedMessage(const CommandInterfaceMessage* message) +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::reportUnexpectedMessage(const CommandInterfaceMessage* message) { - reportError(QString{"Unexpected message received (%1) during state %2"}.arg(message->getCommand()).arg(_taskState)); + this->reportError(QString{"Unexpected message received (%1) during state %2"}.arg(message->getCommand()).arg(this->_taskState)); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::sendControlMessage(ControlMessage::ControlType type) +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::sendControlMessage(ControlMessage::ControlType type) { auto control = this->template createMessage<ControlMessage>(type); _cmdInterface.sendMessage(*control); } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::sendGrinderInfoMessage() +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::sendGrinderInfoMessage() { auto info = this->template createMessage<GrinderInfoMessage>(_userData); _cmdInterface.sendMessage(*info); } -template<typename ClassType> -std::unique_ptr<CommandInterfaceMessage> ExternalClassifierTask<ClassType>::handleHelloMessage(CommandInterfaceMessage* message) +template<typename BaseType, typename ClassType> +std::unique_ptr<CommandInterfaceMessage> ExternalClassifierTask<BaseType, ClassType>::handleHelloMessage(CommandInterfaceMessage* message) { // Got a Hello from a client, so a connection has been established auto helloMessage = this->template castMessage<HelloMessage>(message); - changeTaskState(TaskState::Connected, QString{"Connected to '%1'"}.arg(helloMessage->getInfo())); + this->changeTaskState(MachineLearningTask::TaskState::Connected, QString{"Connected to '%1'"}.arg(helloMessage->getInfo())); // Let the task-specific implementation handle the rest from here on connectionReady(); @@ -215,22 +215,22 @@ std::unique_ptr<CommandInterfaceMessage> ExternalClassifierTask<ClassType>::hand return nullptr; } -template<typename ClassType> -std::unique_ptr<CommandInterfaceMessage> ExternalClassifierTask<ClassType>::handleStatusMessage(CommandInterfaceMessage* message) +template<typename BaseType, typename ClassType> +std::unique_ptr<CommandInterfaceMessage> ExternalClassifierTask<BaseType, ClassType>::handleStatusMessage(CommandInterfaceMessage* message) { auto statusMessage = this->template castMessage<StatusMessage>(message); if (statusMessage->hasContent(StatusMessage::ContentFlag::LogMessage)) - addLogMessage(statusMessage->getLogMessage()); + this->addLogMessage(statusMessage->getLogMessage()); if (statusMessage->hasContent(StatusMessage::ContentFlag::Progress)) - setProgress(statusMessage->getProgress()); + this->setProgress(statusMessage->getProgress()); return nullptr; } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::processFinished(int exitCode, QProcess::ExitStatus exitStatus) +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::processFinished(int exitCode, QProcess::ExitStatus exitStatus) { if (exitStatus == QProcess::NormalExit) { @@ -242,14 +242,14 @@ void ExternalClassifierTask<ClassType>::processFinished(int exitCode, QProcess:: _process = nullptr; } -template<typename ClassType> -void ExternalClassifierTask<ClassType>::outputAvailable() +template<typename BaseType, typename ClassType> +void ExternalClassifierTask<BaseType, ClassType>::outputAvailable() { if (_process) { QString output = _process->readAllStandardOutput().toStdString().data(); for (auto line : output.split("\n")) - addLogMessage("> " + line, false); + this->addLogMessage("> " + line, false); } } diff --git a/Grinder/ml/external/tasks/ExternalClassifierTrainingTask.cpp b/Grinder/ml/external/tasks/ExternalClassifierTrainingTask.cpp index d53bf1768d0e65d411f44dac1b3e8ba6c8a566eb..86283ca36eb82ca5b294d2b4eda88ec8e471cc39 100644 --- a/Grinder/ml/external/tasks/ExternalClassifierTrainingTask.cpp +++ b/Grinder/ml/external/tasks/ExternalClassifierTrainingTask.cpp @@ -28,48 +28,6 @@ void ExternalClassifierTrainingTask::registerMessageHandlers() registerMessageHandler<AckMessage>(AckMessage::Command, &ExternalClassifierTrainingTask::handleAckMessage); } -void ExternalClassifierTrainingTask::processEngineStart(EngineExecutionContext& ctx, const MachineLearningTaskData& data) -{ - ExternalClassifierTask::processEngineStart(ctx, data); - - // Create the HDF5 file in the output directory - _h5Export = std::make_unique<HDF5Export>(getTrainingDataFile()); - _h5Export->initExport(data.imageTiler.getTileSize(), ctx.imageReferences().size() * data.imageTiler.tiles().size(), 1, HDF5Export::ExportFlag::ExportTags|HDF5Export::ExportFlag::MergeTags); -} - -void ExternalClassifierTrainingTask::processEnginePass(EngineExecutionContext& ctx, const MachineLearningTaskData& data) -{ - ExternalClassifierTask::processEnginePass(ctx, data); - - if (!_h5Export) - { - reportError("The HDF5 file isn't ready for exporting"); - return; - } - - std::vector<cv::Mat> imageDataTiles = data.imageTiler.tile(data.imageData); - std::vector<cv::Mat> imageTagsDataTiles; - - if (!data.imageTagsData.empty()) - imageTagsDataTiles = data.imageTiler.tile(data.imageTagsData); - - for (unsigned int i = 0; i < imageDataTiles.size(); ++i) - { - if (!imageTagsDataTiles.empty()) - _h5Export->exportImageEx(imageDataTiles[i], {imageTagsDataTiles[i]}); - else - _h5Export->exportImage(imageDataTiles[i]); - } -} - -void ExternalClassifierTrainingTask::processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) -{ - ExternalClassifierTask::processEngineEnd(ctx, data); - - // We no longer need the HDF5 export - _h5Export = nullptr; -} - void ExternalClassifierTrainingTask::connectionReady() { // The task has been fully initialized @@ -86,6 +44,16 @@ void ExternalClassifierTrainingTask::cleanupTask() const QFile::remove(getTrainingDataFile()); } +QString ExternalClassifierTrainingTask::getTrainingDataFile() const +{ + return FileUtils::getTemporaryFileName(FILE_TRAINING_DATA_HDF5); +} + +HDF5Export::ExportFlags ExternalClassifierTrainingTask::getTrainingDataFileFlags() const +{ + return HDF5Export::ExportFlag::ExportTags|HDF5Export::ExportFlag::MergeTags; +} + void ExternalClassifierTrainingTask::sendStartTrainingMessage() { if (isProcessRunning()) @@ -147,8 +115,3 @@ void ExternalClassifierTrainingTask::handleStartTrainingAck(const AckMessage* ac else reportUnexpectedMessage(ackMessage); } - -QString ExternalClassifierTrainingTask::getTrainingDataFile() const -{ - return FileUtils::getTemporaryFileName(FILE_TRAINING_DATA_HDF5); -} diff --git a/Grinder/ml/external/tasks/ExternalClassifierTrainingTask.h b/Grinder/ml/external/tasks/ExternalClassifierTrainingTask.h index 8a7b4fe78ce5df4493555df43c2b266593095321..449b08f1f0fa91169f2fb594f34c3db3191cadfd 100644 --- a/Grinder/ml/external/tasks/ExternalClassifierTrainingTask.h +++ b/Grinder/ml/external/tasks/ExternalClassifierTrainingTask.h @@ -7,11 +7,11 @@ #define EXTERNALCLASSIFIERTRAININGTASK_H #include "ExternalClassifierTask.h" -#include "project/exporters/HDF5Export.h" +#include "ml/tasks/MachineLearningTrainingTask.h" namespace grndr { - class ExternalClassifierTrainingTask : public ExternalClassifierTask<ExternalClassifierTrainingTask> + class ExternalClassifierTrainingTask : public ExternalClassifierTask<MachineLearningTrainingTask, ExternalClassifierTrainingTask> { Q_OBJECT @@ -24,17 +24,18 @@ namespace grndr public: virtual void registerMessageHandlers() override; - public: - virtual void processEngineStart(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; - virtual void processEnginePass(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; - virtual void processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; - protected: virtual void connectionReady() override; protected: + virtual void taskError(QString error) override { reportError(error); } + virtual void cleanupTask() const override; + protected: + virtual QString getTrainingDataFile() const override; + virtual HDF5Export::ExportFlags getTrainingDataFileFlags() const override; + private: void sendStartTrainingMessage(); @@ -44,18 +45,12 @@ namespace grndr void handleStartTrainingAck(const AckMessage* ackMessage); - private: - QString getTrainingDataFile() const; - private: enum TrainingTaskState { StartTraining = TypeSpecificBase, Training, }; - - private: - std::unique_ptr<HDF5Export> _h5Export; }; } diff --git a/Grinder/ml/tasks/MachineLearningInferenceTask.cpp b/Grinder/ml/tasks/MachineLearningInferenceTask.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b8b3e3acff594be56535c1f095631eeeebc4efb1 --- /dev/null +++ b/Grinder/ml/tasks/MachineLearningInferenceTask.cpp @@ -0,0 +1,53 @@ +/****************************************************************************** + * File: MachineLearningInferenceTask.cpp + * Date: 08.10.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "MachineLearningInferenceTask.h" +#include "project/ImageReference.h" +#include "engine/EngineExecutionContext.h" +#include "image/ImageTags.h" +#include "task/TaskExceptions.h" + +void MachineLearningInferenceTask::inferImage(const cv::Mat& imageData, const ImageReference* imageRef) +{ + addLogMessage(QString{"\tPerforming inference on '%1'..."}.arg(imageRef->getImageFileName())); + + if (imageData.empty()) + throw TaskException{this, _EXCPT(QString{"Failed to get data for '%1'"}.arg(imageRef->getImageFilePath()))}; + + // Fill the image tile queue + _tileQueue.beginQueue(_imageTiler.tile(imageData)); + + // Perform the actual inference + performInference(); +} + +void MachineLearningInferenceTask::processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) +{ + MachineLearningTask::processEngineEnd(ctx, data); + + if (!ctx.wasAborted() && completeTaskOnEnd()) + { + setProgress(1.0f); + addLogMessage("All images have been processed"); + } + + // The inference has finished, so finish the task + finishInference(); +} + +void MachineLearningInferenceTask::finishInferenceResult() +{ + ImageInferenceResults results; + + for (unsigned int tagIndex = 0; tagIndex < _inputImageTags->tags().size(); ++tagIndex) + { + auto imageTag = _inputImageTags->tags().at(tagIndex).get(); + results[imageTag] = _imageTiler.combine(_tileQueue.getResultTiles(imageTag)); + } + + // Inform about the arrived results via an event + emit inferImageFinished(results); +} diff --git a/Grinder/ml/tasks/MachineLearningInferenceTask.h b/Grinder/ml/tasks/MachineLearningInferenceTask.h new file mode 100644 index 0000000000000000000000000000000000000000..278b9af980021a11dcb21fa71afff01968ff008d --- /dev/null +++ b/Grinder/ml/tasks/MachineLearningInferenceTask.h @@ -0,0 +1,42 @@ +/****************************************************************************** + * File: MachineLearningInferenceTask.h + * Date: 08.10.2019 + *****************************************************************************/ + +#ifndef MACHINELEARNINGINFERENCETASK_H +#define MACHINELEARNINGINFERENCETASK_H + +#include "MachineLearningTask.h" +#include "InferenceTileQueue.h" + +namespace grndr +{ + class MachineLearningInferenceTask : public MachineLearningTask + { + Q_OBJECT + + public: + using MachineLearningTask::MachineLearningTask; + + public: + virtual void inferImage(const cv::Mat& imageData, const ImageReference* imageRef) override; + + public: + virtual void processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; + + protected: + virtual void performInference() = 0; + virtual void finishInference() = 0; + + virtual bool completeTaskOnEnd() const { return true; } + + protected: + void finishInferenceResult(); + + protected: + InferenceTileQueue _tileQueue; + unsigned int _currentTileIndex{0}; + }; +} + +#endif diff --git a/Grinder/ml/tasks/MachineLearningTask.h b/Grinder/ml/tasks/MachineLearningTask.h index 947517a38c498a575d27f832da08aebdb9e09939..54dda90fa79d2eb06f49347b528029a6c82ddbca 100644 --- a/Grinder/ml/tasks/MachineLearningTask.h +++ b/Grinder/ml/tasks/MachineLearningTask.h @@ -73,6 +73,9 @@ namespace grndr protected: virtual void verifyTask() const override; + protected: + virtual void taskError(QString error) { Q_UNUSED(error); } + protected: QString getFullStateFile() const; diff --git a/Grinder/ml/tasks/MachineLearningTrainingTask.cpp b/Grinder/ml/tasks/MachineLearningTrainingTask.cpp new file mode 100644 index 0000000000000000000000000000000000000000..46ba4a071596e782af456b4bb885e4940f776904 --- /dev/null +++ b/Grinder/ml/tasks/MachineLearningTrainingTask.cpp @@ -0,0 +1,50 @@ +/****************************************************************************** + * File: MachineLearningTrainingTask.cpp + * Date: 08.10.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "MachineLearningTrainingTask.h" +#include "engine/EngineExecutionContext.h" + +void MachineLearningTrainingTask::processEngineStart(EngineExecutionContext& ctx, const MachineLearningTaskData& data) +{ + MachineLearningTask::processEngineStart(ctx, data); + + // Create the HDF5 file in the output directory + _h5Export = std::make_unique<HDF5Export>(getTrainingDataFile()); + _h5Export->initExport(data.imageTiler.getTileSize(), ctx.imageReferences().size() * data.imageTiler.tiles().size(), 1, getTrainingDataFileFlags()); +} + +void MachineLearningTrainingTask::processEnginePass(EngineExecutionContext& ctx, const MachineLearningTaskData& data) +{ + MachineLearningTask::processEnginePass(ctx, data); + + if (!_h5Export) + { + taskError("The HDF5 file isn't ready for exporting"); + return; + } + + std::vector<cv::Mat> imageDataTiles = data.imageTiler.tile(data.imageData); + std::vector<cv::Mat> imageTagsDataTiles; + + if (!data.imageTagsData.empty()) + imageTagsDataTiles = data.imageTiler.tile(data.imageTagsData); + + for (unsigned int i = 0; i < imageDataTiles.size(); ++i) + { + if (!imageTagsDataTiles.empty()) + _h5Export->exportImageEx(imageDataTiles[i], {imageTagsDataTiles[i]}); + else + _h5Export->exportImage(imageDataTiles[i]); + } +} + +void MachineLearningTrainingTask::processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) +{ + MachineLearningTask::processEngineEnd(ctx, data); + + // We no longer need the HDF5 export + _h5Export = nullptr; +} diff --git a/Grinder/ml/tasks/MachineLearningTrainingTask.h b/Grinder/ml/tasks/MachineLearningTrainingTask.h new file mode 100644 index 0000000000000000000000000000000000000000..52cd5a90a555cefd7de3af3349f2cf0d8c42c50a --- /dev/null +++ b/Grinder/ml/tasks/MachineLearningTrainingTask.h @@ -0,0 +1,35 @@ +/****************************************************************************** + * File: MachineLearningTrainingTask.h + * Date: 08.10.2019 + *****************************************************************************/ + +#ifndef MACHINELEARNINGTRAININGTASK_H +#define MACHINELEARNINGTRAININGTASK_H + +#include "MachineLearningTask.h" +#include "project/exporters/HDF5Export.h" + +namespace grndr +{ + class MachineLearningTrainingTask : public MachineLearningTask + { + Q_OBJECT + + public: + using MachineLearningTask::MachineLearningTask; + + public: + virtual void processEngineStart(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; + virtual void processEnginePass(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; + virtual void processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; + + protected: + virtual QString getTrainingDataFile() const = 0; + virtual HDF5Export::ExportFlags getTrainingDataFileFlags() const = 0; + + protected: + std::unique_ptr<HDF5Export> _h5Export; + }; +} + +#endif