From 476a5c6b0466b028355e520cf6b23fda07e2c18c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20M=C3=BCller?= <d_muel20@uni-muenster.de> Date: Wed, 14 Aug 2019 13:51:44 +0200 Subject: [PATCH] * Continued work on ML integration --- Grinder/Grinder.pro | 22 +- Grinder/Version.h | 4 +- Grinder/controller/TaskController.cpp | 334 +++++++++--------- Grinder/controller/TaskController.h | 121 +++---- Grinder/engine/ProcessorBase.cpp | 14 +- Grinder/engine/ProcessorBase.h | 119 ++++--- Grinder/engine/ProcessorBase.impl.h | 75 ++-- .../engine/processors/GrabCutProcessor.cpp | 16 +- Grinder/ml/MachineLearningMethod.cpp | 7 - Grinder/ml/MachineLearningMethod.h | 5 +- Grinder/ml/MachineLearningMethod.impl.h | 4 +- Grinder/ml/MachineLearningMethodBase.cpp | 6 + Grinder/ml/MachineLearningMethodBase.h | 14 +- Grinder/ml/MachineLearningTaskSpawner.h | 25 +- Grinder/ml/MachineLearningTaskSpawner.impl.h | 43 ++- Grinder/ml/MachineLearningTaskSpawnerBase.cpp | 6 - Grinder/ml/MachineLearningTaskSpawnerBase.h | 8 +- .../ml/barista/BaristaClassifierMethod.cpp | 5 + Grinder/ml/barista/BaristaClassifierMethod.h | 3 + .../barista/BaristaClassifierTaskSpawner.cpp | 20 ++ .../ml/barista/BaristaClassifierTaskSpawner.h | 16 +- .../BaristaClassifierTaskSpawner.impl.h | 19 + .../barista/blocks/BaristaClassifierBlock.cpp | 6 +- .../barista/blocks/BaristaClassifierBlock.h | 4 +- Grinder/ml/blocks/MachineLearningBlock.cpp | 19 + Grinder/ml/blocks/MachineLearningBlock.h | 44 +-- .../ml/blocks/MachineLearningMethodBlock.h | 70 ++++ ...pl.h => MachineLearningMethodBlock.impl.h} | 22 +- Grinder/ml/blocks/TrainingBlock.cpp | 31 ++ Grinder/ml/blocks/TrainingBlock.h | 42 +++ .../MachineLearningMethodProcessor.h | 38 ++ .../MachineLearningMethodProcessor.impl.h | 35 ++ .../ml/processors/MachineLearningProcessor.h | 30 +- .../MachineLearningProcessor.impl.h | 65 +++- Grinder/ml/processors/TrainingProcessor.cpp | 29 ++ Grinder/ml/processors/TrainingProcessor.h | 24 ++ Grinder/pipeline/BlockCatalog.cpp | 2 + Grinder/pipeline/BlockType.cpp | 2 + Grinder/pipeline/BlockType.h | 2 + Grinder/pipeline/PortType.cpp | 1 + Grinder/pipeline/PortType.h | 1 + 41 files changed, 916 insertions(+), 437 deletions(-) delete mode 100644 Grinder/ml/MachineLearningMethod.cpp create mode 100644 Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h create mode 100644 Grinder/ml/blocks/MachineLearningBlock.cpp create mode 100644 Grinder/ml/blocks/MachineLearningMethodBlock.h rename Grinder/ml/blocks/{MachineLearningBlock.impl.h => MachineLearningMethodBlock.impl.h} (62%) create mode 100644 Grinder/ml/blocks/TrainingBlock.cpp create mode 100644 Grinder/ml/blocks/TrainingBlock.h create mode 100644 Grinder/ml/processors/MachineLearningMethodProcessor.h create mode 100644 Grinder/ml/processors/MachineLearningMethodProcessor.impl.h create mode 100644 Grinder/ml/processors/TrainingProcessor.cpp create mode 100644 Grinder/ml/processors/TrainingProcessor.h diff --git a/Grinder/Grinder.pro b/Grinder/Grinder.pro index 35f0ec6..338272e 100644 --- a/Grinder/Grinder.pro +++ b/Grinder/Grinder.pro @@ -441,7 +441,6 @@ SOURCES += \ ui/properties/editors/PathPropertyEditor.cpp \ ui/dlg/BrowseDialog.cpp \ engine/data/DataType.cpp \ - ml/MachineLearningMethod.cpp \ ml/MachineLearningMethodBase.cpp \ ml/MachineLearningConfiguration.cpp \ ml/barista/BaristaClassifierConfiguration.cpp \ @@ -453,7 +452,10 @@ SOURCES += \ ml/barista/properties/BaristaNetworkProperty.cpp \ ui/barista/editors/BaristaNetworkPropertyEditor.cpp \ ml/properties/MachineLearningStateProperty.cpp \ - ui/ml/editors/MachineLearningStatePropertyEditor.cpp + ui/ml/editors/MachineLearningStatePropertyEditor.cpp \ + ml/blocks/MachineLearningBlock.cpp \ + ml/blocks/TrainingBlock.cpp \ + ml/processors/TrainingProcessor.cpp HEADERS += \ ui/mainwnd/GrinderWindow.h \ @@ -968,15 +970,21 @@ HEADERS += \ ml/MachineLearningTaskSpawnerBase.h \ ml/MachineLearningTaskSpawner.h \ ml/MachineLearningTaskSpawner.impl.h \ - ml/blocks/MachineLearningBlock.h \ - ml/blocks/MachineLearningBlock.impl.h \ - ml/processors/MachineLearningProcessor.h \ - ml/processors/MachineLearningProcessor.impl.h \ ml/barista/blocks/BaristaClassifierBlock.h \ ml/barista/properties/BaristaNetworkProperty.h \ ui/barista/editors/BaristaNetworkPropertyEditor.h \ ml/properties/MachineLearningStateProperty.h \ - ui/ml/editors/MachineLearningStatePropertyEditor.h + ui/ml/editors/MachineLearningStatePropertyEditor.h \ + ml/blocks/MachineLearningMethodBlock.h \ + ml/blocks/MachineLearningMethodBlock.impl.h \ + ml/blocks/MachineLearningBlock.h \ + ml/blocks/TrainingBlock.h \ + ml/processors/TrainingProcessor.h \ + ml/processors/MachineLearningMethodProcessor.h \ + ml/processors/MachineLearningMethodProcessor.impl.h \ + ml/processors/MachineLearningProcessor.h \ + ml/processors/MachineLearningProcessor.impl.h \ + ml/barista/BaristaClassifierTaskSpawner.impl.h FORMS += \ ui/mainwnd/GrinderWindow.ui \ diff --git a/Grinder/Version.h b/Grinder/Version.h index 62d30ff..f4f176d 100644 --- a/Grinder/Version.h +++ b/Grinder/Version.h @@ -10,14 +10,14 @@ #define GRNDR_INFO_TITLE "Grinder" #define GRNDR_INFO_COPYRIGHT "Copyright (c) WWU Muenster" -#define GRNDR_INFO_DATE "13.8.2019" +#define GRNDR_INFO_DATE "14.8.2019" #define GRNDR_INFO_COMPANY "WWU Muenster" #define GRNDR_INFO_WEBSITE "http://www.uni-muenster.de" #define GRNDR_VERSION_MAJOR 0 #define GRNDR_VERSION_MINOR 15 #define GRNDR_VERSION_REVISION 0 -#define GRNDR_VERSION_BUILD 374 +#define GRNDR_VERSION_BUILD 375 namespace grndr { diff --git a/Grinder/controller/TaskController.cpp b/Grinder/controller/TaskController.cpp index d659d7e..ccd1893 100644 --- a/Grinder/controller/TaskController.cpp +++ b/Grinder/controller/TaskController.cpp @@ -1,164 +1,170 @@ -/****************************************************************************** - * File: TaskController.cpp - * Date: 01.11.2018 - *****************************************************************************/ - -#include "Grinder.h" -#include "TaskController.h" -#include "task/TaskPool.h" -#include "task/TaskExceptions.h" -#include "ui/task/TaskPoolWidget.h" -#include "ui/task/ConfigureTaskDialog.h" -#include "util/StringUtils.h" - -TaskController::TaskController(TaskPool* taskPool) : GenericController("Task pool"), - _taskPool{taskPool} -{ - if (!taskPool) - throw std::invalid_argument{_EXCPT("taskPool may not be null")}; - - // Listen for task pool events - connect(_taskPool, &TaskPool::taskCreated, this, &TaskController::taskCreated); - connect(_taskPool, &TaskPool::taskRemoved, this, &TaskController::taskRemoved); - - // Periodically update all tasks - _updateTimer.setInterval(1000); - _updateTimer.start(); - - connect(&_updateTimer, &QTimer::timeout, this, &TaskController::updateTasks); -} - -void TaskController::assignUiComponents(TaskPoolWidget* taskPoolWidget) -{ - if (!taskPoolWidget) - throw std::invalid_argument{_EXCPT("taskPoolWidget may not be null")}; - - _taskPoolWidget = taskPoolWidget; -} - -std::shared_ptr<Task> TaskController::createTask(TaskType type, QString name) const -{ - return callControllerFunction("Creating a new task", [this](TaskType type, QString name) { - return _taskPool->createTask(type, name); - }, type, name); -} - -void TaskController::removeTask(const Task* task) const -{ - if (task) - { - callControllerFunction("Removing a task", [this](const Task* task) { - _taskPool->removeTask(task); - return true; - }, task); - } -} - -void TaskController::removeAllTasks() -{ - if (_taskPool) - { - auto tasks = _taskPool->tasks(); - - for (const auto& task : tasks) - removeTask(task.get()); - } -} - -void TaskController::startTask(Task* task) const -{ - callControllerFunction("Starting a task", [](Task* task) { - task->startTask(); - return true; - }, task); -} - -void TaskController::pauseTask(Task* task, bool pause) const -{ - callControllerFunction("Starting a task", [pause](Task* task) { - task->pauseTask(pause); - return true; - }, task); -} - -void TaskController::refreshTask(Task* task) const -{ - callControllerFunction("Refreshing a task", [](Task* task) { - task->refreshTask(); - return true; - }, task); -} - -void TaskController::stopTask(Task* task) const -{ - callControllerFunction("Stopping a task", [](Task* task) { - task->stopTask(); - return true; - }, task); -} - -void TaskController::newTask(TaskType type) const -{ - callControllerFunction("Adding a new task", [this](TaskType type) { - QString newTaskName = StringUtils::generateUniqueItemName(_taskPool->tasks(), type + " task", &Task::getName); - - if (auto task = _taskPool->createTask(type, newTaskName, false)) - { - ConfigureTaskDialog dlg{task.get(), true}; - - if (dlg.exec() == QDialog::Accepted) - _taskPool->addTask(task); - } - - return true; - }, type); -} - -void TaskController::configureTask(Task* task) const -{ - callControllerFunction("Configuring a task", [](Task* task) { - if (!task->isRunning()) - { - ConfigureTaskDialog dlg{task, false}; - dlg.exec(); - } - - return true; - }, task); -} - -void TaskController::validateTaskName(QString name, const Task* task) const -{ - auto taskPool = task ? task->taskPool() : _taskPool; - - // Name must be non-empty and unique - if (!name.isEmpty()) - { - auto existingTask = taskPool->tasks().selectByName(name).get(); - - if (existingTask && existingTask != task) - throw TaskPoolException{taskPool, _EXCPT(QString{"A task with the name '%1' already exists"}.arg(name))}; - } - else - throw TaskPoolException{taskPool, _EXCPT("The task name may not be empty")}; -} - -void TaskController::taskCreated(const std::shared_ptr<Task>& task) const -{ - if (_taskPoolWidget) - _taskPoolWidget->addTask(task); -} - -void TaskController::taskRemoved(const std::shared_ptr<Task>& task) const -{ - if (_taskPoolWidget) - _taskPoolWidget->removeTask(task); -} - -void TaskController::updateTasks() -{ - callControllerFunction("Updating tasks", [this]() { - _taskPool->updateTasks(); - return true; - }); -} +/****************************************************************************** + * File: TaskController.cpp + * Date: 01.11.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "TaskController.h" +#include "task/TaskPool.h" +#include "task/TaskExceptions.h" +#include "ui/task/TaskPoolWidget.h" +#include "ui/task/ConfigureTaskDialog.h" +#include "util/StringUtils.h" + +TaskController::TaskController(TaskPool* taskPool) : GenericController("Task pool"), + _taskPool{taskPool} +{ + if (!taskPool) + throw std::invalid_argument{_EXCPT("taskPool may not be null")}; + + // Listen for task pool events + connect(_taskPool, &TaskPool::taskCreated, this, &TaskController::taskCreated); + connect(_taskPool, &TaskPool::taskRemoved, this, &TaskController::taskRemoved); + + // Periodically update all tasks + _updateTimer.setInterval(1000); + _updateTimer.start(); + + connect(&_updateTimer, &QTimer::timeout, this, &TaskController::updateTasks); +} + +void TaskController::assignUiComponents(TaskPoolWidget* taskPoolWidget) +{ + if (!taskPoolWidget) + throw std::invalid_argument{_EXCPT("taskPoolWidget may not be null")}; + + _taskPoolWidget = taskPoolWidget; +} + +std::shared_ptr<Task> TaskController::createTask(TaskType type, QString name, bool addToPool) const +{ + return callControllerFunction("Creating a new task", [this](TaskType type, QString name, bool addToPool) { + return _taskPool->createTask(type, name, addToPool); + }, type, name, addToPool); +} + +void TaskController::addTask(std::shared_ptr<Task>& task) +{ + if (task) + _taskPool->addTask(task); +} + +void TaskController::removeTask(const Task* task) const +{ + if (task) + { + callControllerFunction("Removing a task", [this](const Task* task) { + _taskPool->removeTask(task); + return true; + }, task); + } +} + +void TaskController::removeAllTasks() +{ + if (_taskPool) + { + auto tasks = _taskPool->tasks(); + + for (const auto& task : tasks) + removeTask(task.get()); + } +} + +void TaskController::startTask(Task* task) const +{ + callControllerFunction("Starting a task", [](Task* task) { + task->startTask(); + return true; + }, task); +} + +void TaskController::pauseTask(Task* task, bool pause) const +{ + callControllerFunction("Starting a task", [pause](Task* task) { + task->pauseTask(pause); + return true; + }, task); +} + +void TaskController::refreshTask(Task* task) const +{ + callControllerFunction("Refreshing a task", [](Task* task) { + task->refreshTask(); + return true; + }, task); +} + +void TaskController::stopTask(Task* task) const +{ + callControllerFunction("Stopping a task", [](Task* task) { + task->stopTask(); + return true; + }, task); +} + +void TaskController::newTask(TaskType type) const +{ + callControllerFunction("Adding a new task", [this](TaskType type) { + QString newTaskName = StringUtils::generateUniqueItemName(_taskPool->tasks(), type + " task", &Task::getName); + + if (auto task = _taskPool->createTask(type, newTaskName, false)) + { + ConfigureTaskDialog dlg{task.get(), true}; + + if (dlg.exec() == QDialog::Accepted) + _taskPool->addTask(task); + } + + return true; + }, type); +} + +void TaskController::configureTask(Task* task) const +{ + callControllerFunction("Configuring a task", [](Task* task) { + if (!task->isRunning()) + { + ConfigureTaskDialog dlg{task, false}; + dlg.exec(); + } + + return true; + }, task); +} + +void TaskController::validateTaskName(QString name, const Task* task) const +{ + auto taskPool = task ? task->taskPool() : _taskPool; + + // Name must be non-empty and unique + if (!name.isEmpty()) + { + auto existingTask = taskPool->tasks().selectByName(name).get(); + + if (existingTask && existingTask != task) + throw TaskPoolException{taskPool, _EXCPT(QString{"A task with the name '%1' already exists"}.arg(name))}; + } + else + throw TaskPoolException{taskPool, _EXCPT("The task name may not be empty")}; +} + +void TaskController::taskCreated(const std::shared_ptr<Task>& task) const +{ + if (_taskPoolWidget) + _taskPoolWidget->addTask(task); +} + +void TaskController::taskRemoved(const std::shared_ptr<Task>& task) const +{ + if (_taskPoolWidget) + _taskPoolWidget->removeTask(task); +} + +void TaskController::updateTasks() +{ + callControllerFunction("Updating tasks", [this]() { + _taskPool->updateTasks(); + return true; + }); +} diff --git a/Grinder/controller/TaskController.h b/Grinder/controller/TaskController.h index af95bce..f733ac8 100644 --- a/Grinder/controller/TaskController.h +++ b/Grinder/controller/TaskController.h @@ -1,60 +1,61 @@ -/****************************************************************************** - * File: TaskController.h - * Date: 01.11.2018 - *****************************************************************************/ - -#ifndef TASKCONTROLLER_H -#define TASKCONTROLLER_H - -#include "GenericController.h" -#include "task/TaskType.h" - -namespace grndr -{ - class TaskPool; - class Task; - class TaskPoolWidget; - - class TaskController : public GenericController - { - Q_OBJECT - - public: - TaskController(TaskPool* taskPool); - - public: - void assignUiComponents(TaskPoolWidget* taskPoolWidget); - - public: - std::shared_ptr<Task> createTask(TaskType type, QString name = "") const; - void removeTask(const Task* task) const; - void removeAllTasks(); - - void startTask(Task* task) const; - void pauseTask(Task* task, bool pause = true) const; - void refreshTask(Task* task) const; - void stopTask(Task* task) const; - - void newTask(TaskType type) const; - void configureTask(Task* task) const; - - public: - void validateTaskName(QString name, const Task* task = nullptr) const; - - private slots: - void taskCreated(const std::shared_ptr<Task>& task) const; - void taskRemoved(const std::shared_ptr<Task>& task) const; - - void updateTasks(); - - private: - TaskPool* _taskPool{nullptr}; - - TaskPoolWidget* _taskPoolWidget{nullptr}; - - private: - QTimer _updateTimer; - }; -} - -#endif +/****************************************************************************** + * File: TaskController.h + * Date: 01.11.2018 + *****************************************************************************/ + +#ifndef TASKCONTROLLER_H +#define TASKCONTROLLER_H + +#include "GenericController.h" +#include "task/TaskType.h" + +namespace grndr +{ + class TaskPool; + class Task; + class TaskPoolWidget; + + class TaskController : public GenericController + { + Q_OBJECT + + public: + TaskController(TaskPool* taskPool); + + public: + void assignUiComponents(TaskPoolWidget* taskPoolWidget); + + public: + std::shared_ptr<Task> createTask(TaskType type, QString name = "", bool addToPool = true) const; + void addTask(std::shared_ptr<Task>& task); + void removeTask(const Task* task) const; + void removeAllTasks(); + + void startTask(Task* task) const; + void pauseTask(Task* task, bool pause = true) const; + void refreshTask(Task* task) const; + void stopTask(Task* task) const; + + void newTask(TaskType type) const; + void configureTask(Task* task) const; + + public: + void validateTaskName(QString name, const Task* task = nullptr) const; + + private slots: + void taskCreated(const std::shared_ptr<Task>& task) const; + void taskRemoved(const std::shared_ptr<Task>& task) const; + + void updateTasks(); + + private: + TaskPool* _taskPool{nullptr}; + + TaskPoolWidget* _taskPoolWidget{nullptr}; + + private: + QTimer _updateTimer; + }; +} + +#endif diff --git a/Grinder/engine/ProcessorBase.cpp b/Grinder/engine/ProcessorBase.cpp index 863240c..4bac0ec 100644 --- a/Grinder/engine/ProcessorBase.cpp +++ b/Grinder/engine/ProcessorBase.cpp @@ -56,10 +56,18 @@ DataDescriptor ProcessorBase::getBestDataDescriptor(const DataDescriptor& dataDe return DataDescriptor{}; } -DataBlob* ProcessorBase::portData(EngineExecutionContext& ctx, const Port* port, const Connection* connection, bool required, bool convert) const +DataBlob* ProcessorBase::portData(EngineExecutionContext& ctx, const Port* port, const Connection* connection, bool required, bool convert, bool persistentData) const { auto dataPort = resolveDataPort(port, connection, required); - DataBlob* data = dataPort ? ctx.data().get(dataPort) : nullptr; + DataBlob* data = nullptr; + + if (dataPort) + { + if (persistentData) + data = ctx.persistentData().get(dataPort); + else + data = ctx.data().get(dataPort); + } if (!data && required) throwProcessorException(QString{"No data could be retrieved for port '%1'"}.arg(port->getName())); @@ -100,7 +108,7 @@ cv::Mat ProcessorBase::getBypassedImage(const DataBlob* dataBlob, bool outputIsG void ProcessorBase::throwProcessorException(QString what) const { - throw ProcessorException{this, _EXCPT(QString{"%1: %2"}.arg(_block->getName()).arg(what))}; + throw ProcessorException{this, _EXCPT(what)}; } const Port* ProcessorBase::resolveDataPort(const Port* port, const Connection* connection, bool required) const diff --git a/Grinder/engine/ProcessorBase.h b/Grinder/engine/ProcessorBase.h index 526c0c2..dde06a9 100644 --- a/Grinder/engine/ProcessorBase.h +++ b/Grinder/engine/ProcessorBase.h @@ -1,58 +1,61 @@ -/****************************************************************************** - * File: ProcessorBase.h - * Date: 21.2.2018 - *****************************************************************************/ - -#ifndef PROCESSORBASE_H -#define PROCESSORBASE_H - -#include "engine/EngineExecutionContext.h" -#include "common/properties/PropertyID.h" - -namespace grndr -{ - class Engine; - class Connection; - - class ProcessorBase - { - public: - ProcessorBase(const Block* block); - - public: - const Engine* engine() const { return _engine; } - const Block* block() const { return _block; } - - public: - virtual void execute(EngineExecutionContext& ctx); - - protected: - DataDescriptor getPortDataDescriptor(const Port* port, unsigned int index = 0) const; - DataDescriptor getBestDataDescriptor(const DataDescriptor& dataDesc, const DataDescriptors& targetDescriptors) const; - - DataBlob* portData(EngineExecutionContext& ctx, const Port* port, const Connection* connection = nullptr, bool required = true, bool convert = true) const; - template<typename PropType> - PropType* portProperty(const Port* port, PropertyID propertyID, const Connection* connection = nullptr, bool required = true) const; - - protected: - bool isBlockBypassed() const; - cv::Mat getBypassedImage(const DataBlob* dataBlob, bool outputIsGrayscale = false); - - protected: - void throwProcessorException(QString what) const; - - private: - const Port* resolveDataPort(const Port* port, const Connection* connection, bool required) const; - - protected: - const Engine* _engine{nullptr}; - const Block* _block{nullptr}; - - private: - mutable std::vector<std::shared_ptr<DataBlob>> _portDataCache; - }; -} - -#include "ProcessorBase.impl.h" - -#endif +/****************************************************************************** + * File: ProcessorBase.h + * Date: 21.2.2018 + *****************************************************************************/ + +#ifndef PROCESSORBASE_H +#define PROCESSORBASE_H + +#include "pipeline/BlockType.h" +#include "engine/EngineExecutionContext.h" +#include "common/properties/PropertyID.h" + +namespace grndr +{ + class Engine; + class Connection; + + class ProcessorBase + { + public: + ProcessorBase(const Block* block); + + public: + const Engine* engine() const { return _engine; } + const Block* block() const { return _block; } + + public: + virtual void execute(EngineExecutionContext& ctx); + + protected: + DataDescriptor getPortDataDescriptor(const Port* port, unsigned int index = 0) const; + DataDescriptor getBestDataDescriptor(const DataDescriptor& dataDesc, const DataDescriptors& targetDescriptors) const; + + DataBlob* portData(EngineExecutionContext& ctx, const Port* port, const Connection* connection = nullptr, bool required = true, bool convert = true, bool persistentData = false) const; + template<typename ValueType> + ValueType portData(EngineExecutionContext& ctx, const Port* port, QString dataName, bool required = true, BlockType blockType = BlockType::Undefined, bool persistentData = false) const; + template<typename PropType> + PropType* portProperty(const Port* port, PropertyID propertyID, const Connection* connection = nullptr, bool required = true) const; + + protected: + bool isBlockBypassed() const; + cv::Mat getBypassedImage(const DataBlob* dataBlob, bool outputIsGrayscale = false); + + protected: + void throwProcessorException(QString what) const; + + private: + const Port* resolveDataPort(const Port* port, const Connection* connection, bool required) const; + + protected: + const Engine* _engine{nullptr}; + const Block* _block{nullptr}; + + private: + mutable std::vector<std::shared_ptr<DataBlob>> _portDataCache; + }; +} + +#include "ProcessorBase.impl.h" + +#endif diff --git a/Grinder/engine/ProcessorBase.impl.h b/Grinder/engine/ProcessorBase.impl.h index e6d3724..fa7f8ae 100644 --- a/Grinder/engine/ProcessorBase.impl.h +++ b/Grinder/engine/ProcessorBase.impl.h @@ -1,24 +1,51 @@ -/****************************************************************************** - * File: ProcessorBase.impl.h - * Date: 20.4.2018 - *****************************************************************************/ - -#include "Grinder.h" -#include "ProcessorBase.h" -#include "pipeline/Block.h" - -template<typename PropType> -PropType* ProcessorBase::portProperty(const Port* port, PropertyID propertyID, const Connection* connection, bool required) const -{ - if (auto dataPort = resolveDataPort(port, connection, required)) - { - auto property = _block->portProperty<PropType>(dataPort, propertyID); - - if (!property && required) - throwProcessorException(QString{"No property with ID '%2' could be retrieved for port '%1'"}.arg(port->getName()).arg(propertyID)); - - return property; - } - else - return nullptr; -} +/****************************************************************************** + * File: ProcessorBase.impl.h + * Date: 20.4.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "ProcessorBase.h" +#include "pipeline/Block.h" + +template<typename ValueType> +ValueType ProcessorBase::portData(EngineExecutionContext& ctx, const Port* port, QString dataName, bool required, BlockType blockType, bool persistentData) const +{ + auto connections = port->getConnections(port->getDirection(), blockType); + + if (connections.size() == 1) + { + auto sourcePort = connections[0]->sourcePort(); + + if (persistentData) + { + if (ctx.persistentData().contains(sourcePort, dataName)) + return ctx.persistentData().get<ValueType>(sourcePort, dataName); + } + else + { + if (ctx.data().contains(sourcePort, dataName)) + return ctx.data().get<ValueType>(sourcePort, dataName); + } + } + + if (required) + throwProcessorException(QString{"Data '%2' could be retrieved for port '%1'"}.arg(port->getName()).arg(dataName)); + + return {}; +} + +template<typename PropType> +PropType* ProcessorBase::portProperty(const Port* port, PropertyID propertyID, const Connection* connection, bool required) const +{ + if (auto dataPort = resolveDataPort(port, connection, required)) + { + auto property = _block->portProperty<PropType>(dataPort, propertyID); + + if (!property && required) + throwProcessorException(QString{"No property with ID '%2' could be retrieved for port '%1'"}.arg(port->getName()).arg(propertyID)); + + return property; + } + else + return nullptr; +} diff --git a/Grinder/engine/processors/GrabCutProcessor.cpp b/Grinder/engine/processors/GrabCutProcessor.cpp index 99d13cf..abde6ea 100644 --- a/Grinder/engine/processors/GrabCutProcessor.cpp +++ b/Grinder/engine/processors/GrabCutProcessor.cpp @@ -37,21 +37,19 @@ void GrabCutProcessor::execute(EngineExecutionContext& ctx) { if (!isBlockBypassed()) { - cv::Mat mask; - cv::Mat backgroundModel; - cv::Mat foregroundModel; - // Get any preceding and succeeding GrabCut blocks auto predecessors = _block->predecessorPort()->getConnections(Port::Direction::In, BlockType::GrabCut); auto successors = _block->successorPort()->getConnections(Port::Direction::Out, BlockType::GrabCut); - if (predecessors.size() == 1) // Initialize the GrabCut algorithm using the data from the predecessor - { - auto predecessor = predecessors[0]->sourcePort(); + // Initialize the GrabCut algorithm using the data from the predecessor + cv::Mat backgroundModel = portData<cv::Mat>(ctx, _block->predecessorPort(), Data_Value_BGModel, false, BlockType::GrabCut); + cv::Mat foregroundModel = portData<cv::Mat>(ctx, _block->predecessorPort(), Data_Value_FGModel, false, BlockType::GrabCut); + cv::Mat mask; + + if (predecessors.size() == 1) + { mask = generateGrabCutMask(maskBlob->getMatrix(), false); - backgroundModel = ctx.data().get<cv::Mat>(predecessor, Data_Value_BGModel); - foregroundModel = ctx.data().get<cv::Mat>(predecessor, Data_Value_FGModel); } else // Initialize the GrabCut algorithm from scratch { diff --git a/Grinder/ml/MachineLearningMethod.cpp b/Grinder/ml/MachineLearningMethod.cpp deleted file mode 100644 index 0069233..0000000 --- a/Grinder/ml/MachineLearningMethod.cpp +++ /dev/null @@ -1,7 +0,0 @@ -/****************************************************************************** - * File: MachineLearningMethod.cpp - * Date: 13.8.2019 - *****************************************************************************/ - -#include "Grinder.h" -#include "MachineLearningMethod.h" diff --git a/Grinder/ml/MachineLearningMethod.h b/Grinder/ml/MachineLearningMethod.h index 3edec2d..0f00caa 100644 --- a/Grinder/ml/MachineLearningMethod.h +++ b/Grinder/ml/MachineLearningMethod.h @@ -23,7 +23,10 @@ namespace grndr using spawner_type = SpawnerType; public: - virtual std::unique_ptr<MachineLearningTaskSpawnerBase> createTaskSpawner(InvocationMode mode) override; + using MachineLearningMethodBase::MachineLearningMethodBase; + + public: + virtual std::unique_ptr<MachineLearningTaskSpawnerBase> createTaskSpawner() const override; public: config_type& config() { return _config; } diff --git a/Grinder/ml/MachineLearningMethod.impl.h b/Grinder/ml/MachineLearningMethod.impl.h index 04ebec1..f531b65 100644 --- a/Grinder/ml/MachineLearningMethod.impl.h +++ b/Grinder/ml/MachineLearningMethod.impl.h @@ -8,7 +8,7 @@ #include "MachineLearningExceptions.h" template<typename ConfigType, typename SpawnerType> -std::unique_ptr<MachineLearningTaskSpawnerBase> MachineLearningMethod<ConfigType, SpawnerType>::createTaskSpawner(InvocationMode mode) +std::unique_ptr<MachineLearningTaskSpawnerBase> MachineLearningMethod<ConfigType, SpawnerType>::createTaskSpawner() const { - return std::make_unique<spawner_type>(mode, _config); + return std::make_unique<spawner_type>(_config); } diff --git a/Grinder/ml/MachineLearningMethodBase.cpp b/Grinder/ml/MachineLearningMethodBase.cpp index 43e3aa4..125e832 100644 --- a/Grinder/ml/MachineLearningMethodBase.cpp +++ b/Grinder/ml/MachineLearningMethodBase.cpp @@ -5,3 +5,9 @@ #include "Grinder.h" #include "MachineLearningMethodBase.h" + +MachineLearningMethodBase::MachineLearningMethodBase(QString methodName) : + _methodName{methodName} +{ + +} diff --git a/Grinder/ml/MachineLearningMethodBase.h b/Grinder/ml/MachineLearningMethodBase.h index 14d6889..6302def 100644 --- a/Grinder/ml/MachineLearningMethodBase.h +++ b/Grinder/ml/MachineLearningMethodBase.h @@ -18,17 +18,19 @@ namespace grndr Q_OBJECT public: - enum class InvocationMode - { - Training, - Inference, - }; + MachineLearningMethodBase(QString methodName); public: - virtual std::unique_ptr<MachineLearningTaskSpawnerBase> createTaskSpawner(InvocationMode mode) = 0; + virtual std::unique_ptr<MachineLearningTaskSpawnerBase> createTaskSpawner() const = 0; public: virtual QStringList getAvailableStates() const = 0; + + public: + QString getMethodName() const { return _methodName; } + + protected: + QString _methodName{""}; }; } diff --git a/Grinder/ml/MachineLearningTaskSpawner.h b/Grinder/ml/MachineLearningTaskSpawner.h index 7a0d089..b99048d 100644 --- a/Grinder/ml/MachineLearningTaskSpawner.h +++ b/Grinder/ml/MachineLearningTaskSpawner.h @@ -7,21 +7,42 @@ #define MACHINELEARNINGTASKSPAWNER_H #include "MachineLearningTaskSpawnerBase.h" +#include "task/TaskType.h" namespace grndr { class MachineLearningConfiguration; + class Task; - template<typename ConfigType> + template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> class MachineLearningTaskSpawner : public MachineLearningTaskSpawnerBase { static_assert(std::is_base_of<MachineLearningConfiguration, ConfigType>::value, "ConfigType must be derived from MachineLearningConfiguration"); + static_assert(std::is_base_of<Task, TrainingTaskType>::value, "TrainingTaskType must be derived from Task"); + static_assert(std::is_base_of<Task, InferenceTaskType>::value, "TrainingTaskType must be derived from Task"); public: using config_type = ConfigType; + using training_task_type = TrainingTaskType; + using inference_task_type = InferenceTaskType; public: - MachineLearningTaskSpawner(MachineLearningMethodBase::InvocationMode mode, const config_type& config); + MachineLearningTaskSpawner(const config_type& config); + + public: + virtual std::shared_ptr<Task> spawnTrainingTask(QString state, QString name) const override; + virtual std::shared_ptr<Task> spawnInferenceTask(QString state, QString name) const override; + + protected: + virtual void configureTrainingTask(training_task_type* task, QString state) const = 0; + virtual void configureInferenceTask(inference_task_type* task, QString state) const = 0; + + private: + template<typename TaskType> + std::shared_ptr<TaskType> createTask(QString name) const; + + std::shared_ptr<training_task_type> createTrainingTask(QString name) const { return createTask<training_task_type>(name); } + std::shared_ptr<inference_task_type> createInferenceTask(QString name) const { return createTask<inference_task_type>(name); } protected: const config_type& _config; diff --git a/Grinder/ml/MachineLearningTaskSpawner.impl.h b/Grinder/ml/MachineLearningTaskSpawner.impl.h index 0ac03a6..8a87f14 100644 --- a/Grinder/ml/MachineLearningTaskSpawner.impl.h +++ b/Grinder/ml/MachineLearningTaskSpawner.impl.h @@ -5,10 +5,49 @@ #include "Grinder.h" #include "MachineLearningTaskSpawner.h" +#include "MachineLearningExceptions.h" +#include "core/GrinderApplication.h" -template<typename ConfigType> -MachineLearningTaskSpawner<ConfigType>::MachineLearningTaskSpawner(MachineLearningMethodBase::InvocationMode mode, const MachineLearningTaskSpawner::config_type& config) : MachineLearningTaskSpawnerBase(mode), +template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> +MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::MachineLearningTaskSpawner(const config_type& config) : _config{config} { } + +template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> +std::shared_ptr<Task> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::spawnTrainingTask(QString state, QString name) const +{ + // If any of the following functions fails, an exception will be thrown + auto task = createTrainingTask(name); + configureTrainingTask(task.get(), state); + + auto sharedTask = std::dynamic_pointer_cast<Task>(task); + grinder()->taskController().addTask(sharedTask); + return sharedTask; +} + +template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> +std::shared_ptr<Task> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::spawnInferenceTask(QString state, QString name) const +{ + // If any of the following functions fails, an exception will be thrown + auto task = createInferenceTask(name); + configureInferenceTask(task.get(), state); + + auto sharedTask = std::dynamic_pointer_cast<Task>(task); + grinder()->taskController().addTask(sharedTask); + return sharedTask; +} + +template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> +template<typename TaskType> +std::shared_ptr<TaskType> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::createTask(QString name) const +{ + if (auto task = grinder()->taskController().createTask(TaskType::type_value, name, false)) + { + if (auto typedTask = std::dynamic_pointer_cast<TaskType>(task)) + return typedTask; + } + + throw MachineLearningException{_EXCPT(QString{"Unable to create a task of type '%1'"}.arg(TaskType::type_value))}; +} diff --git a/Grinder/ml/MachineLearningTaskSpawnerBase.cpp b/Grinder/ml/MachineLearningTaskSpawnerBase.cpp index 15f258b..1ac396e 100644 --- a/Grinder/ml/MachineLearningTaskSpawnerBase.cpp +++ b/Grinder/ml/MachineLearningTaskSpawnerBase.cpp @@ -5,9 +5,3 @@ #include "Grinder.h" #include "MachineLearningTaskSpawnerBase.h" - -MachineLearningTaskSpawnerBase::MachineLearningTaskSpawnerBase(MachineLearningMethodBase::InvocationMode mode) : - _invocationMode{mode} -{ - -} diff --git a/Grinder/ml/MachineLearningTaskSpawnerBase.h b/Grinder/ml/MachineLearningTaskSpawnerBase.h index ab0926e..669a46a 100644 --- a/Grinder/ml/MachineLearningTaskSpawnerBase.h +++ b/Grinder/ml/MachineLearningTaskSpawnerBase.h @@ -10,13 +10,13 @@ namespace grndr { + class Task; + class MachineLearningTaskSpawnerBase { public: - MachineLearningTaskSpawnerBase(MachineLearningMethodBase::InvocationMode mode); - - protected: - MachineLearningMethodBase::InvocationMode _invocationMode{MachineLearningMethodBase::InvocationMode::Training}; + virtual std::shared_ptr<Task> spawnTrainingTask(QString state, QString name) const = 0; + virtual std::shared_ptr<Task> spawnInferenceTask(QString state, QString name) const = 0; }; } diff --git a/Grinder/ml/barista/BaristaClassifierMethod.cpp b/Grinder/ml/barista/BaristaClassifierMethod.cpp index f73cf03..d6c8552 100644 --- a/Grinder/ml/barista/BaristaClassifierMethod.cpp +++ b/Grinder/ml/barista/BaristaClassifierMethod.cpp @@ -6,6 +6,11 @@ #include "Grinder.h" #include "BaristaClassifierMethod.h" +BaristaClassifierMethod::BaristaClassifierMethod() : MachineLearningMethod("Barista") +{ + +} + QStringList BaristaClassifierMethod::getAvailableStates() const { QString outputDir = _config.getOutputDirectory(); diff --git a/Grinder/ml/barista/BaristaClassifierMethod.h b/Grinder/ml/barista/BaristaClassifierMethod.h index 1da882e..31dba64 100644 --- a/Grinder/ml/barista/BaristaClassifierMethod.h +++ b/Grinder/ml/barista/BaristaClassifierMethod.h @@ -16,6 +16,9 @@ namespace grndr { Q_OBJECT + public: + BaristaClassifierMethod(); + public: virtual QStringList getAvailableStates() const override; }; diff --git a/Grinder/ml/barista/BaristaClassifierTaskSpawner.cpp b/Grinder/ml/barista/BaristaClassifierTaskSpawner.cpp index 3f43e21..6e6d63f 100644 --- a/Grinder/ml/barista/BaristaClassifierTaskSpawner.cpp +++ b/Grinder/ml/barista/BaristaClassifierTaskSpawner.cpp @@ -5,3 +5,23 @@ #include "Grinder.h" #include "BaristaClassifierTaskSpawner.h" + +void BaristaClassifierTaskSpawner::configureTrainingTask(BaristaTrainingTask* task, QString state) const +{ + Q_UNUSED(state); + + configureTask(task); + + // Configure training specific settings + task->setMaxIterations(_config.getMaxIterations()); + task->setDisplayInterval(_config.getDisplayInterval()); + task->setSnapshotInterval(_config.getSnapshotInterval()); +} + +void BaristaClassifierTaskSpawner::configureInferenceTask(BaristaInferenceTask* task, QString state) const +{ + configureTask(task); + + // Configure inference specific settings + task->setNetworkStateFile(state); +} diff --git a/Grinder/ml/barista/BaristaClassifierTaskSpawner.h b/Grinder/ml/barista/BaristaClassifierTaskSpawner.h index d21073f..fee72d7 100644 --- a/Grinder/ml/barista/BaristaClassifierTaskSpawner.h +++ b/Grinder/ml/barista/BaristaClassifierTaskSpawner.h @@ -8,14 +8,26 @@ #include "ml/MachineLearningTaskSpawner.h" #include "BaristaClassifierConfiguration.h" +#include "tasks/BaristaTrainingTask.h" +#include "tasks/BaristaInferenceTask.h" namespace grndr { - class BaristaClassifierTaskSpawner : public MachineLearningTaskSpawner<BaristaClassifierConfiguration> + class BaristaClassifierTaskSpawner : public MachineLearningTaskSpawner<BaristaClassifierConfiguration, BaristaTrainingTask, BaristaInferenceTask> { public: - using MachineLearningTaskSpawner<BaristaClassifierConfiguration>::MachineLearningTaskSpawner; + using MachineLearningTaskSpawner<BaristaClassifierConfiguration, BaristaTrainingTask, BaristaInferenceTask>::MachineLearningTaskSpawner; + + protected: + virtual void configureTrainingTask(BaristaTrainingTask* task, QString state) const override; + virtual void configureInferenceTask(BaristaInferenceTask* task, QString state) const override; + + private: + template<typename TaskType> + void configureTask(TaskType* task) const; }; } +#include "BaristaClassifierTaskSpawner.impl.h" + #endif diff --git a/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h b/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h new file mode 100644 index 0000000..57907f2 --- /dev/null +++ b/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h @@ -0,0 +1,19 @@ +/****************************************************************************** + * File: BaristaClassifierTaskSpawner.impl.h + * Date: 14.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "BaristaClassifierTaskSpawner.h" + +template<typename TaskType> +void BaristaClassifierTaskSpawner::configureTask(TaskType* task) const +{ + // Apply general settings to the task + task->setBaristaPort(_config.getBaristaPort()); + task->setLibraryPath(_config.getLibraryPath()); + + task->setNetwork(_config.network()); + task->setOutputDirectory(_config.getOutputDirectory()); + task->setRemoteDirectory(_config.getRemoteDirectory()); +} diff --git a/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp b/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp index 948e374..ce428c5 100644 --- a/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp +++ b/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp @@ -9,14 +9,14 @@ const BlockCategory BaristaClassifierBlock::category_value = BlockCategory::MachineLearning; const BlockType BaristaClassifierBlock::type_value = BlockType::BaristaClassifier; -BaristaClassifierBlock::BaristaClassifierBlock(Pipeline* pipeline, QString name) : MachineLearningBlock(pipeline, type_value, category_value, name) +BaristaClassifierBlock::BaristaClassifierBlock(Pipeline* pipeline, QString name) : MachineLearningMethodBlock(pipeline, type_value, category_value, name) { } void BaristaClassifierBlock::createProperties() { - MachineLearningBlock::createProperties(); + MachineLearningMethodBlock::createProperties(); setPropertyGroup("General"); @@ -55,7 +55,7 @@ void BaristaClassifierBlock::createProperties() bool BaristaClassifierBlock::updateProperties(PropertyBase* updatedProp) { - bool updated = MachineLearningBlock::updateProperties(updatedProp); + bool updated = MachineLearningMethodBlock::updateProperties(updatedProp); // If the output directory has been modified, update the currently selected state if (updatedProp == _outputDirectory.get()) diff --git a/Grinder/ml/barista/blocks/BaristaClassifierBlock.h b/Grinder/ml/barista/blocks/BaristaClassifierBlock.h index 97e40b7..50c4307 100644 --- a/Grinder/ml/barista/blocks/BaristaClassifierBlock.h +++ b/Grinder/ml/barista/blocks/BaristaClassifierBlock.h @@ -6,13 +6,13 @@ #ifndef BARISTACLASSIFIERBLOCK_H #define BARISTACLASSIFIERBLOCK_H -#include "ml/blocks/MachineLearningBlock.h" +#include "ml/blocks/MachineLearningMethodBlock.h" #include "ml/barista/BaristaClassifierMethod.h" #include "ml/barista/properties/BaristaNetworkProperty.h" namespace grndr { - class BaristaClassifierBlock : public MachineLearningBlock<BaristaClassifierMethod> + class BaristaClassifierBlock : public MachineLearningMethodBlock<BaristaClassifierMethod> { Q_OBJECT diff --git a/Grinder/ml/blocks/MachineLearningBlock.cpp b/Grinder/ml/blocks/MachineLearningBlock.cpp new file mode 100644 index 0000000..1841dda --- /dev/null +++ b/Grinder/ml/blocks/MachineLearningBlock.cpp @@ -0,0 +1,19 @@ +/****************************************************************************** + * File: MachineLearningBlock.cpp + * Date: 13.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "MachineLearningBlock.h" + +void MachineLearningBlock::createPorts() +{ + DataDescriptors methodPortDataDescs = {DataDescriptor::customDescriptor("Machine learning method", DataType::MachineLearningMethod)}; + _methodPort = createPort(PortType::Method, Port::Direction::In, methodPortDataDescs, "Method"); + + DataDescriptors statePortDataDescs = {DataDescriptor::customDescriptor("Model state", DataType::MachineLearningState)}; + _statePort = createPort(PortType::State, Port::Direction::In, statePortDataDescs, "State"); + + DataDescriptors inPortDataDescs = {DataDescriptor::imageDescriptor(true, DataDescriptor::ValueType::Any), DataDescriptor::imageDescriptor(false, DataDescriptor::ValueType::Any)}; + _inPort = createPort(PortType::ImageIn, Port::Direction::In, inPortDataDescs, "In"); +} diff --git a/Grinder/ml/blocks/MachineLearningBlock.h b/Grinder/ml/blocks/MachineLearningBlock.h index 5fd5480..e790f04 100644 --- a/Grinder/ml/blocks/MachineLearningBlock.h +++ b/Grinder/ml/blocks/MachineLearningBlock.h @@ -7,64 +7,32 @@ #define MACHINELEARNINGBLOCK_H #include "pipeline/Block.h" -#include "ml/properties/MachineLearningStateProperty.h" namespace grndr { - class MachineLearningMethodBase; - - template<typename MethodType> class MachineLearningBlock : public Block { - static_assert(std::is_base_of<MachineLearningMethodBase, MethodType>::value, "MethodType must be derived from MachineLearningMethodBase"); - - public: - using method_type = MethodType; - - public: - MachineLearningBlock(Pipeline* pipeline, BlockType type, BlockCategory category, QString name = ""); + Q_OBJECT public: - virtual void initBlock() override; + using Block::Block; public: - method_type& method() { return _method; } - const method_type& method() const { return _method; } - - public: - virtual std::unique_ptr<ProcessorBase> createProcessor() const override; - - public: - auto state() { return dynamic_cast<MachineLearningStateProperty*>(_state.get()); } - auto state() const { return dynamic_cast<const MachineLearningStateProperty*>(_state.get()); } - Port* methodPort() { return _methodPort.get(); } const Port* methodPort() const { return _methodPort.get(); } Port* statePort() { return _statePort.get(); } const Port* statePort() const { return _statePort.get(); } + Port* inPort() { return _inPort.get(); } + const Port* inPort() const { return _inPort.get(); } protected: - virtual void createProperties() override; - virtual bool updateProperties(PropertyBase* updatedProp = nullptr) override; virtual void createPorts() override; - protected: - virtual void updateConfiguration() = 0; - - protected: - bool checkStateAvailability(); - - protected: - method_type _method; - - protected: - std::shared_ptr<PropertyBase> _state; - + private: std::shared_ptr<Port> _methodPort; std::shared_ptr<Port> _statePort; + std::shared_ptr<Port> _inPort; }; } -#include "MachineLearningBlock.impl.h" - #endif diff --git a/Grinder/ml/blocks/MachineLearningMethodBlock.h b/Grinder/ml/blocks/MachineLearningMethodBlock.h new file mode 100644 index 0000000..9cf873b --- /dev/null +++ b/Grinder/ml/blocks/MachineLearningMethodBlock.h @@ -0,0 +1,70 @@ +/****************************************************************************** + * File: MachineLearningMethodBlock.h + * Date: 13.8.2019 + *****************************************************************************/ + +#ifndef MACHINELEARNINGMETHODBLOCK_H +#define MACHINELEARNINGMETHODBLOCK_H + +#include "pipeline/Block.h" +#include "ml/properties/MachineLearningStateProperty.h" + +namespace grndr +{ + class MachineLearningMethodBase; + + template<typename MethodType> + class MachineLearningMethodBlock : public Block + { + static_assert(std::is_base_of<MachineLearningMethodBase, MethodType>::value, "MethodType must be derived from MachineLearningMethodBase"); + + public: + using method_type = MethodType; + + public: + MachineLearningMethodBlock(Pipeline* pipeline, BlockType type, BlockCategory category, QString name = ""); + + public: + virtual void initBlock() override; + + public: + method_type& method() { return _method; } + const method_type& method() const { return _method; } + + public: + virtual std::unique_ptr<ProcessorBase> createProcessor() const override; + + public: + auto state() { return dynamic_cast<MachineLearningStateProperty*>(_state.get()); } + auto state() const { return dynamic_cast<const MachineLearningStateProperty*>(_state.get()); } + + Port* methodPort() { return _methodPort.get(); } + const Port* methodPort() const { return _methodPort.get(); } + Port* statePort() { return _statePort.get(); } + const Port* statePort() const { return _statePort.get(); } + + protected: + virtual void createProperties() override; + virtual bool updateProperties(PropertyBase* updatedProp = nullptr) override; + virtual void createPorts() override; + + protected: + virtual void updateConfiguration() = 0; + + protected: + bool checkStateAvailability(); + + protected: + method_type _method; + + protected: + std::shared_ptr<PropertyBase> _state; + + std::shared_ptr<Port> _methodPort; + std::shared_ptr<Port> _statePort; + }; +} + +#include "MachineLearningMethodBlock.impl.h" + +#endif diff --git a/Grinder/ml/blocks/MachineLearningBlock.impl.h b/Grinder/ml/blocks/MachineLearningMethodBlock.impl.h similarity index 62% rename from Grinder/ml/blocks/MachineLearningBlock.impl.h rename to Grinder/ml/blocks/MachineLearningMethodBlock.impl.h index c9e17bd..fef7a18 100644 --- a/Grinder/ml/blocks/MachineLearningBlock.impl.h +++ b/Grinder/ml/blocks/MachineLearningMethodBlock.impl.h @@ -1,20 +1,20 @@ /****************************************************************************** - * File: MachineLearningBlock.impl.h + * File: MachineLearningMethodBlock.impl.h * Date: 13.8.2019 *****************************************************************************/ #include "Grinder.h" -#include "MachineLearningBlock.h" -#include "ml/processors/MachineLearningProcessor.h" +#include "MachineLearningMethodBlock.h" +#include "ml/processors/MachineLearningMethodProcessor.h" template<typename MethodType> -MachineLearningBlock<MethodType>::MachineLearningBlock(Pipeline* pipeline, BlockType type, BlockCategory category, QString name) : Block(pipeline, type, category, name) +MachineLearningMethodBlock<MethodType>::MachineLearningMethodBlock(Pipeline* pipeline, BlockType type, BlockCategory category, QString name) : Block(pipeline, type, category, name) { _bypassPossible = false; } template<typename MethodType> -void MachineLearningBlock<MethodType>::initBlock() +void MachineLearningMethodBlock<MethodType>::initBlock() { Block::initBlock(); @@ -22,13 +22,13 @@ void MachineLearningBlock<MethodType>::initBlock() } template<typename MethodType> -std::unique_ptr<ProcessorBase> MachineLearningBlock<MethodType>::createProcessor() const +std::unique_ptr<ProcessorBase> MachineLearningMethodBlock<MethodType>::createProcessor() const { - return std::make_unique<MachineLearningProcessor<method_type>>(this); + return std::make_unique<MachineLearningMethodProcessor<method_type>>(this); } template<typename MethodType> -void MachineLearningBlock<MethodType>::createProperties() +void MachineLearningMethodBlock<MethodType>::createProperties() { Block::createProperties(); @@ -41,7 +41,7 @@ void MachineLearningBlock<MethodType>::createProperties() } template<typename MethodType> -bool MachineLearningBlock<MethodType>::updateProperties(PropertyBase* updatedProp) +bool MachineLearningMethodBlock<MethodType>::updateProperties(PropertyBase* updatedProp) { bool updated = Block::updateProperties(updatedProp); @@ -51,7 +51,7 @@ bool MachineLearningBlock<MethodType>::updateProperties(PropertyBase* updatedPro } template<typename MethodType> -void MachineLearningBlock<MethodType>::createPorts() +void MachineLearningMethodBlock<MethodType>::createPorts() { DataDescriptors methodPortDataDescs = {DataDescriptor::customDescriptor("Machine learning method", DataType::MachineLearningMethod)}; _methodPort = createPort(PortType::Method, Port::Direction::Out, methodPortDataDescs, "Method"); @@ -61,7 +61,7 @@ void MachineLearningBlock<MethodType>::createPorts() } template<typename MethodType> -bool MachineLearningBlock<MethodType>::checkStateAvailability() +bool MachineLearningMethodBlock<MethodType>::checkStateAvailability() { auto states = _method.getAvailableStates(); diff --git a/Grinder/ml/blocks/TrainingBlock.cpp b/Grinder/ml/blocks/TrainingBlock.cpp new file mode 100644 index 0000000..8f3af15 --- /dev/null +++ b/Grinder/ml/blocks/TrainingBlock.cpp @@ -0,0 +1,31 @@ +/****************************************************************************** + * File: TrainingBlock.cpp + * Date: 13.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "TrainingBlock.h" +#include "ml/processors/TrainingProcessor.h" + +const BlockType TrainingBlock::type_value = BlockType::Training; +const BlockCategory TrainingBlock::category_value = BlockCategory::MachineLearning; + +TrainingBlock::TrainingBlock(Pipeline* pipeline, QString name) : MachineLearningBlock(pipeline, type_value, category_value, name) +{ + +} + +std::unique_ptr<ProcessorBase> TrainingBlock::createProcessor() const +{ + return std::make_unique<TrainingProcessor>(this); +} + +void TrainingBlock::createPorts() +{ + MachineLearningBlock::createPorts(); + + _tagsBitmapPort = createPort(PortType::ImageTagsBitmap, Port::Direction::In, {DataDescriptor::imageDescriptor(false, DataDescriptor::ValueType::UInt32)}, "Tags bitmap"); + + DataDescriptors trainedStatePortDataDescs = {DataDescriptor::customDescriptor("Trained state", DataType::MachineLearningState)}; + _trainedStatePort = createPort(PortType::TrainedState, Port::Direction::Out, trainedStatePortDataDescs, "State"); +} diff --git a/Grinder/ml/blocks/TrainingBlock.h b/Grinder/ml/blocks/TrainingBlock.h new file mode 100644 index 0000000..05e1368 --- /dev/null +++ b/Grinder/ml/blocks/TrainingBlock.h @@ -0,0 +1,42 @@ +/****************************************************************************** + * File: TrainingBlock.h + * Date: 13.8.2019 + *****************************************************************************/ + +#ifndef TRAININGBLOCK_H +#define TRAININGBLOCK_H + +#include "MachineLearningBlock.h" + +namespace grndr +{ + class TrainingBlock : public MachineLearningBlock + { + Q_OBJECT + + public: + static const BlockType type_value; + static const BlockCategory category_value; + + public: + TrainingBlock(Pipeline* pipeline, QString name = ""); + + public: + virtual std::unique_ptr<ProcessorBase> createProcessor() const override; + + public: + Port* tagsBitmapPort() { return _tagsBitmapPort.get(); } + const Port* tagsBitmapPort() const { return _tagsBitmapPort.get(); } + Port* trainedStatePort() { return _trainedStatePort.get(); } + const Port* trainedStatePort() const { return _trainedStatePort.get(); } + + protected: + virtual void createPorts() override; + + private: + std::shared_ptr<Port> _tagsBitmapPort; + std::shared_ptr<Port> _trainedStatePort; + }; +} + +#endif diff --git a/Grinder/ml/processors/MachineLearningMethodProcessor.h b/Grinder/ml/processors/MachineLearningMethodProcessor.h new file mode 100644 index 0000000..0a14e0a --- /dev/null +++ b/Grinder/ml/processors/MachineLearningMethodProcessor.h @@ -0,0 +1,38 @@ +/****************************************************************************** + * File: MachineLearningMethodProcessor.h + * Date: 13.8.2019 + *****************************************************************************/ + +#ifndef MACHINELEARNINGMETHODPROCESSOR_H +#define MACHINELEARNINGMETHODPROCESSOR_H + +#include "engine/Processor.h" +#include "ml/blocks/MachineLearningMethodBlock.h" + +namespace grndr +{ + class MachineLearningMethodBase; + + template<typename MethodType> + class MachineLearningMethodProcessor : public Processor<MachineLearningMethodBlock<MethodType>> + { + static_assert(std::is_base_of<MachineLearningMethodBase, MethodType>::value, "MethodType must be derived from MachineLearningMethodBase"); + + private: + static const char* Data_Value_Method; + static const char* Data_Value_State; + + public: + using method_type = MethodType; + + public: + MachineLearningMethodProcessor(const Block* block); + + public: + virtual void execute(EngineExecutionContext& ctx) override; + }; +} + +#include "MachineLearningMethodProcessor.impl.h" + +#endif diff --git a/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h b/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h new file mode 100644 index 0000000..9c88e62 --- /dev/null +++ b/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h @@ -0,0 +1,35 @@ +/****************************************************************************** + * File: MachineLearningMethodProcessor.impl.h + * Date: 13.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "MachineLearningMethodProcessor.h" +#include "ml/MachineLearningMethodBase.h" + +Q_DECLARE_METATYPE(const MachineLearningMethodBase*) + +template<typename MethodType> +const char* MachineLearningMethodProcessor<MethodType>::Data_Value_Method = "Method"; +template<typename MethodType> +const char* MachineLearningMethodProcessor<MethodType>::Data_Value_State = "State"; + +template<typename MethodType> +MachineLearningMethodProcessor<MethodType>::MachineLearningMethodProcessor(const Block* block) : Processor<MachineLearningMethodBlock<MethodType>>(block) +{ + +} + +template<typename MethodType> +void MachineLearningMethodProcessor<MethodType>::execute(EngineExecutionContext& ctx) +{ + if (ctx.isFirstImage()) + { + // Verify all configuration settings + this->_block->method().config().verifyConfiguration(); + } + + // Store the machine learning method and model state in the context data + ctx.data().set(this->_block->methodPort(), Data_Value_Method, static_cast<const MachineLearningMethodBase*>(&this->_block->method())); + ctx.data().set(this->_block->statePort(), Data_Value_State, this->_block->state()->getValue()); +} diff --git a/Grinder/ml/processors/MachineLearningProcessor.h b/Grinder/ml/processors/MachineLearningProcessor.h index 758aa28..067c813 100644 --- a/Grinder/ml/processors/MachineLearningProcessor.h +++ b/Grinder/ml/processors/MachineLearningProcessor.h @@ -7,29 +7,43 @@ #define MACHINELEARNINGPROCESSOR_H #include "engine/Processor.h" -#include "ml/blocks/MachineLearningBlock.h" namespace grndr { + class MachineLearningBlock; class MachineLearningMethodBase; + class Task; - template<typename MethodType> - class MachineLearningProcessor : public Processor<MachineLearningBlock<MethodType>> + template<typename BlockType> + class MachineLearningProcessor : public Processor<BlockType> { - static_assert(std::is_base_of<MachineLearningMethodBase, MethodType>::value, "MethodType must be derived from MachineLearningMethodBase"); + static_assert(std::is_base_of<MachineLearningBlock, BlockType>::value, "BlockType must be derived from MachineLearningBlock"); - public: + private: static const char* Data_Value_Method; static const char* Data_Value_State; - public: - using method_type = MethodType; + protected: + enum class SpawnType + { + Training, + Inference, + }; public: - MachineLearningProcessor(const Block* block); + using Processor<BlockType>::Processor; public: virtual void execute(EngineExecutionContext& ctx) override; + + protected: + virtual void execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) = 0; + + protected: + std::shared_ptr<Task> spawnTask(SpawnType type, const MachineLearningMethodBase* method, QString state) const; + + private: + QString getSpawnTypeName(SpawnType type) const; }; } diff --git a/Grinder/ml/processors/MachineLearningProcessor.impl.h b/Grinder/ml/processors/MachineLearningProcessor.impl.h index 6cd16e7..d130f02 100644 --- a/Grinder/ml/processors/MachineLearningProcessor.impl.h +++ b/Grinder/ml/processors/MachineLearningProcessor.impl.h @@ -6,33 +6,66 @@ #include "Grinder.h" #include "MachineLearningProcessor.h" #include "ml/MachineLearningMethodBase.h" +#include "ml/MachineLearningTaskSpawnerBase.h" + +template<typename BlockType> +const char* MachineLearningProcessor<BlockType>::Data_Value_Method = "Method"; +template<typename BlockType> +const char* MachineLearningProcessor<BlockType>::Data_Value_State = "State"; Q_DECLARE_METATYPE(const MachineLearningMethodBase*) -template<typename MethodType> -const char* MachineLearningProcessor<MethodType>::Data_Value_Method = "Method"; -template<typename MethodType> -const char* MachineLearningProcessor<MethodType>::Data_Value_State = "State"; +template<typename BlockType> +void MachineLearningProcessor<BlockType>::execute(EngineExecutionContext& ctx) +{ + if (!this->isBlockBypassed()) + { + // Get the used machine learning method and model state + const MachineLearningMethodBase* method = this->template portData<const MachineLearningMethodBase*>(ctx, this->_block->methodPort(), Data_Value_Method); + QString state = this->template portData<QString>(ctx, this->_block->statePort(), Data_Value_State, false); + + execute(ctx, method, state); + } +} -template<typename MethodType> -MachineLearningProcessor<MethodType>::MachineLearningProcessor(const Block* block) : Processor<MachineLearningBlock<MethodType>>(block) +template<typename BlockType> +std::shared_ptr<Task> MachineLearningProcessor<BlockType>::spawnTask(SpawnType type, const MachineLearningMethodBase* method, QString state) const { + if (auto spawner = method->createTaskSpawner()) + { + QString taskName = QString{"%1 %2 (%3)"}.arg(method->getMethodName()).arg(getSpawnTypeName(type)).arg(this->_block->getFormattedName()); + std::shared_ptr<Task> task; + + switch (type) + { + case SpawnType::Training: + task = spawner->spawnTrainingTask(state, taskName); + break; + + case SpawnType::Inference: + task = spawner->spawnInferenceTask(state, taskName); + break; + } + + return task; + } + else + this->throwProcessorException("Unable to create a task spawner"); + return nullptr; } -template<typename MethodType> -void MachineLearningProcessor<MethodType>::execute(EngineExecutionContext& ctx) +template<typename BlockType> +QString MachineLearningProcessor<BlockType>::getSpawnTypeName(SpawnType type) const { - // Only set persistent data on the first image - if (ctx.isFirstImage()) + switch (type) { - // Verify all configuration settings - this->_block->method().config().verifyConfiguration(); + case SpawnType::Training: + return "Training"; - // Store the machine learning method in the context persistent data - ctx.persistentData().set(this->_block->methodPort(), Data_Value_Method, static_cast<const MachineLearningMethodBase*>(&this->_block->method())); + case SpawnType::Inference: + return "Inference"; } - // Store the model state in the context data - ctx.data().set(this->_block->statePort(), Data_Value_State, this->_block->state()->getValue()); + return ""; } diff --git a/Grinder/ml/processors/TrainingProcessor.cpp b/Grinder/ml/processors/TrainingProcessor.cpp new file mode 100644 index 0000000..dd685ed --- /dev/null +++ b/Grinder/ml/processors/TrainingProcessor.cpp @@ -0,0 +1,29 @@ +/****************************************************************************** + * File: TrainingProcessor.cpp + * Date: 13.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "TrainingProcessor.h" +#include "ml/MachineLearningTaskSpawnerBase.h" + +TrainingProcessor::TrainingProcessor(const Block* block) : MachineLearningProcessor(block) +{ + +} + +void TrainingProcessor::execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) +{ + // Training is only executed in batch mode + if (ctx.hasExecutionFlag(Engine::ExecutionFlag::Batch)) + { + // Spawn the training task when the last image is active + if (ctx.isLastImage()) + { + // TODO: Gather images, labels etc. before spawning + spawnTask(SpawnType::Training, method, state); + } + } + else + throwProcessorException("Training is only possible in batch mode; bypass the training block to avoid this warning"); +} diff --git a/Grinder/ml/processors/TrainingProcessor.h b/Grinder/ml/processors/TrainingProcessor.h new file mode 100644 index 0000000..5dde33b --- /dev/null +++ b/Grinder/ml/processors/TrainingProcessor.h @@ -0,0 +1,24 @@ +/****************************************************************************** + * File: TrainingProcessor.h + * Date: 13.8.2019 + *****************************************************************************/ + +#ifndef TRAININGPROCESSOR_H +#define TRAININGPROCESSOR_H + +#include "MachineLearningProcessor.h" +#include "ml/blocks/TrainingBlock.h" + +namespace grndr +{ + class TrainingProcessor : public MachineLearningProcessor<TrainingBlock> + { + public: + TrainingProcessor(const Block* block); + + protected: + virtual void execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) override; + }; +} + +#endif diff --git a/Grinder/pipeline/BlockCatalog.cpp b/Grinder/pipeline/BlockCatalog.cpp index 917125f..f5cc6e2 100644 --- a/Grinder/pipeline/BlockCatalog.cpp +++ b/Grinder/pipeline/BlockCatalog.cpp @@ -33,6 +33,7 @@ #include "blocks/ResizeBlock.h" #include "blocks/SaveImageBlock.h" +#include "ml/blocks/TrainingBlock.h" #include "ml/barista/blocks/BaristaClassifierBlock.h" #define REGISTER_BLOCK_TYPE(cls, desc) registerBlockType(cls::type_value, cls::category_value, [](Pipeline* pipeline, QString name) { return std::make_unique<cls>(pipeline, name); }, desc) @@ -147,5 +148,6 @@ void BlockCatalog::registerStandardBlocks() REGISTER_BLOCK_TYPE(ResizeBlock, "Resizes its input."); REGISTER_BLOCK_TYPE(SaveImageBlock, "Saves its input to an image file."); + REGISTER_BLOCK_TYPE(TrainingBlock, "Performs training on labeled images using a machine learning method."); REGISTER_BLOCK_TYPE(BaristaClassifierBlock, "Provides the Barista machine learning classifier."); } diff --git a/Grinder/pipeline/BlockType.cpp b/Grinder/pipeline/BlockType.cpp index 1a75790..0c357a5 100644 --- a/Grinder/pipeline/BlockType.cpp +++ b/Grinder/pipeline/BlockType.cpp @@ -43,6 +43,8 @@ const char* BlockType::Edges = "Edges"; const char* BlockType::Watershed = "Watershed"; const char* BlockType::GrabCut = "GrabCut"; +const char* BlockType::Training = "Training"; +const char* BlockType::Inference = "Inference"; const char* BlockType::BaristaClassifier = "BaristaClassifier"; const char* BlockType::ImageTags = "ImageTags"; diff --git a/Grinder/pipeline/BlockType.h b/Grinder/pipeline/BlockType.h index 7d80266..c1a423e 100644 --- a/Grinder/pipeline/BlockType.h +++ b/Grinder/pipeline/BlockType.h @@ -50,6 +50,8 @@ namespace grndr static const char* Watershed; static const char* GrabCut; + static const char* Training; + static const char* Inference; static const char* BaristaClassifier; static const char* ImageTags; diff --git a/Grinder/pipeline/PortType.cpp b/Grinder/pipeline/PortType.cpp index 11cbbd2..880a6a5 100644 --- a/Grinder/pipeline/PortType.cpp +++ b/Grinder/pipeline/PortType.cpp @@ -30,3 +30,4 @@ const char* PortType::Successor = "Successor"; const char* PortType::Method = "Method"; const char* PortType::State = "State"; +const char* PortType::TrainedState = "TrainedState"; diff --git a/Grinder/pipeline/PortType.h b/Grinder/pipeline/PortType.h index 72101b9..07c92f3 100644 --- a/Grinder/pipeline/PortType.h +++ b/Grinder/pipeline/PortType.h @@ -37,6 +37,7 @@ namespace grndr static const char* Method; static const char* State; + static const char* TrainedState; public: using QString::QString; -- GitLab