diff --git a/Grinder/Grinder.pro b/Grinder/Grinder.pro index 35f0ec6c9b34e62107e26900598aa6792cd6bf27..338272ee6e8c756636bb8642aadbf6028608fb52 100644 --- a/Grinder/Grinder.pro +++ b/Grinder/Grinder.pro @@ -441,7 +441,6 @@ SOURCES += \ ui/properties/editors/PathPropertyEditor.cpp \ ui/dlg/BrowseDialog.cpp \ engine/data/DataType.cpp \ - ml/MachineLearningMethod.cpp \ ml/MachineLearningMethodBase.cpp \ ml/MachineLearningConfiguration.cpp \ ml/barista/BaristaClassifierConfiguration.cpp \ @@ -453,7 +452,10 @@ SOURCES += \ ml/barista/properties/BaristaNetworkProperty.cpp \ ui/barista/editors/BaristaNetworkPropertyEditor.cpp \ ml/properties/MachineLearningStateProperty.cpp \ - ui/ml/editors/MachineLearningStatePropertyEditor.cpp + ui/ml/editors/MachineLearningStatePropertyEditor.cpp \ + ml/blocks/MachineLearningBlock.cpp \ + ml/blocks/TrainingBlock.cpp \ + ml/processors/TrainingProcessor.cpp HEADERS += \ ui/mainwnd/GrinderWindow.h \ @@ -968,15 +970,21 @@ HEADERS += \ ml/MachineLearningTaskSpawnerBase.h \ ml/MachineLearningTaskSpawner.h \ ml/MachineLearningTaskSpawner.impl.h \ - ml/blocks/MachineLearningBlock.h \ - ml/blocks/MachineLearningBlock.impl.h \ - ml/processors/MachineLearningProcessor.h \ - ml/processors/MachineLearningProcessor.impl.h \ ml/barista/blocks/BaristaClassifierBlock.h \ ml/barista/properties/BaristaNetworkProperty.h \ ui/barista/editors/BaristaNetworkPropertyEditor.h \ ml/properties/MachineLearningStateProperty.h \ - ui/ml/editors/MachineLearningStatePropertyEditor.h + ui/ml/editors/MachineLearningStatePropertyEditor.h \ + ml/blocks/MachineLearningMethodBlock.h \ + ml/blocks/MachineLearningMethodBlock.impl.h \ + ml/blocks/MachineLearningBlock.h \ + ml/blocks/TrainingBlock.h \ + ml/processors/TrainingProcessor.h \ + ml/processors/MachineLearningMethodProcessor.h \ + ml/processors/MachineLearningMethodProcessor.impl.h \ + ml/processors/MachineLearningProcessor.h \ + ml/processors/MachineLearningProcessor.impl.h \ + ml/barista/BaristaClassifierTaskSpawner.impl.h FORMS += \ ui/mainwnd/GrinderWindow.ui \ diff --git a/Grinder/Version.h b/Grinder/Version.h index 62d30ff5dbc8bde79310a984c0ff50407b7dcb0f..f4f176dd4e9425e5d7053624ce1a6cc91eb968db 100644 --- a/Grinder/Version.h +++ b/Grinder/Version.h @@ -10,14 +10,14 @@ #define GRNDR_INFO_TITLE "Grinder" #define GRNDR_INFO_COPYRIGHT "Copyright (c) WWU Muenster" -#define GRNDR_INFO_DATE "13.8.2019" +#define GRNDR_INFO_DATE "14.8.2019" #define GRNDR_INFO_COMPANY "WWU Muenster" #define GRNDR_INFO_WEBSITE "http://www.uni-muenster.de" #define GRNDR_VERSION_MAJOR 0 #define GRNDR_VERSION_MINOR 15 #define GRNDR_VERSION_REVISION 0 -#define GRNDR_VERSION_BUILD 374 +#define GRNDR_VERSION_BUILD 375 namespace grndr { diff --git a/Grinder/controller/TaskController.cpp b/Grinder/controller/TaskController.cpp index d659d7e27eaa752f44244c5f848fa6a0f88d4475..ccd18930e6f3c6f645a783dba82b8fd4b92e382c 100644 --- a/Grinder/controller/TaskController.cpp +++ b/Grinder/controller/TaskController.cpp @@ -1,164 +1,170 @@ -/****************************************************************************** - * File: TaskController.cpp - * Date: 01.11.2018 - *****************************************************************************/ - -#include "Grinder.h" -#include "TaskController.h" -#include "task/TaskPool.h" -#include "task/TaskExceptions.h" -#include "ui/task/TaskPoolWidget.h" -#include "ui/task/ConfigureTaskDialog.h" -#include "util/StringUtils.h" - -TaskController::TaskController(TaskPool* taskPool) : GenericController("Task pool"), - _taskPool{taskPool} -{ - if (!taskPool) - throw std::invalid_argument{_EXCPT("taskPool may not be null")}; - - // Listen for task pool events - connect(_taskPool, &TaskPool::taskCreated, this, &TaskController::taskCreated); - connect(_taskPool, &TaskPool::taskRemoved, this, &TaskController::taskRemoved); - - // Periodically update all tasks - _updateTimer.setInterval(1000); - _updateTimer.start(); - - connect(&_updateTimer, &QTimer::timeout, this, &TaskController::updateTasks); -} - -void TaskController::assignUiComponents(TaskPoolWidget* taskPoolWidget) -{ - if (!taskPoolWidget) - throw std::invalid_argument{_EXCPT("taskPoolWidget may not be null")}; - - _taskPoolWidget = taskPoolWidget; -} - -std::shared_ptr<Task> TaskController::createTask(TaskType type, QString name) const -{ - return callControllerFunction("Creating a new task", [this](TaskType type, QString name) { - return _taskPool->createTask(type, name); - }, type, name); -} - -void TaskController::removeTask(const Task* task) const -{ - if (task) - { - callControllerFunction("Removing a task", [this](const Task* task) { - _taskPool->removeTask(task); - return true; - }, task); - } -} - -void TaskController::removeAllTasks() -{ - if (_taskPool) - { - auto tasks = _taskPool->tasks(); - - for (const auto& task : tasks) - removeTask(task.get()); - } -} - -void TaskController::startTask(Task* task) const -{ - callControllerFunction("Starting a task", [](Task* task) { - task->startTask(); - return true; - }, task); -} - -void TaskController::pauseTask(Task* task, bool pause) const -{ - callControllerFunction("Starting a task", [pause](Task* task) { - task->pauseTask(pause); - return true; - }, task); -} - -void TaskController::refreshTask(Task* task) const -{ - callControllerFunction("Refreshing a task", [](Task* task) { - task->refreshTask(); - return true; - }, task); -} - -void TaskController::stopTask(Task* task) const -{ - callControllerFunction("Stopping a task", [](Task* task) { - task->stopTask(); - return true; - }, task); -} - -void TaskController::newTask(TaskType type) const -{ - callControllerFunction("Adding a new task", [this](TaskType type) { - QString newTaskName = StringUtils::generateUniqueItemName(_taskPool->tasks(), type + " task", &Task::getName); - - if (auto task = _taskPool->createTask(type, newTaskName, false)) - { - ConfigureTaskDialog dlg{task.get(), true}; - - if (dlg.exec() == QDialog::Accepted) - _taskPool->addTask(task); - } - - return true; - }, type); -} - -void TaskController::configureTask(Task* task) const -{ - callControllerFunction("Configuring a task", [](Task* task) { - if (!task->isRunning()) - { - ConfigureTaskDialog dlg{task, false}; - dlg.exec(); - } - - return true; - }, task); -} - -void TaskController::validateTaskName(QString name, const Task* task) const -{ - auto taskPool = task ? task->taskPool() : _taskPool; - - // Name must be non-empty and unique - if (!name.isEmpty()) - { - auto existingTask = taskPool->tasks().selectByName(name).get(); - - if (existingTask && existingTask != task) - throw TaskPoolException{taskPool, _EXCPT(QString{"A task with the name '%1' already exists"}.arg(name))}; - } - else - throw TaskPoolException{taskPool, _EXCPT("The task name may not be empty")}; -} - -void TaskController::taskCreated(const std::shared_ptr<Task>& task) const -{ - if (_taskPoolWidget) - _taskPoolWidget->addTask(task); -} - -void TaskController::taskRemoved(const std::shared_ptr<Task>& task) const -{ - if (_taskPoolWidget) - _taskPoolWidget->removeTask(task); -} - -void TaskController::updateTasks() -{ - callControllerFunction("Updating tasks", [this]() { - _taskPool->updateTasks(); - return true; - }); -} +/****************************************************************************** + * File: TaskController.cpp + * Date: 01.11.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "TaskController.h" +#include "task/TaskPool.h" +#include "task/TaskExceptions.h" +#include "ui/task/TaskPoolWidget.h" +#include "ui/task/ConfigureTaskDialog.h" +#include "util/StringUtils.h" + +TaskController::TaskController(TaskPool* taskPool) : GenericController("Task pool"), + _taskPool{taskPool} +{ + if (!taskPool) + throw std::invalid_argument{_EXCPT("taskPool may not be null")}; + + // Listen for task pool events + connect(_taskPool, &TaskPool::taskCreated, this, &TaskController::taskCreated); + connect(_taskPool, &TaskPool::taskRemoved, this, &TaskController::taskRemoved); + + // Periodically update all tasks + _updateTimer.setInterval(1000); + _updateTimer.start(); + + connect(&_updateTimer, &QTimer::timeout, this, &TaskController::updateTasks); +} + +void TaskController::assignUiComponents(TaskPoolWidget* taskPoolWidget) +{ + if (!taskPoolWidget) + throw std::invalid_argument{_EXCPT("taskPoolWidget may not be null")}; + + _taskPoolWidget = taskPoolWidget; +} + +std::shared_ptr<Task> TaskController::createTask(TaskType type, QString name, bool addToPool) const +{ + return callControllerFunction("Creating a new task", [this](TaskType type, QString name, bool addToPool) { + return _taskPool->createTask(type, name, addToPool); + }, type, name, addToPool); +} + +void TaskController::addTask(std::shared_ptr<Task>& task) +{ + if (task) + _taskPool->addTask(task); +} + +void TaskController::removeTask(const Task* task) const +{ + if (task) + { + callControllerFunction("Removing a task", [this](const Task* task) { + _taskPool->removeTask(task); + return true; + }, task); + } +} + +void TaskController::removeAllTasks() +{ + if (_taskPool) + { + auto tasks = _taskPool->tasks(); + + for (const auto& task : tasks) + removeTask(task.get()); + } +} + +void TaskController::startTask(Task* task) const +{ + callControllerFunction("Starting a task", [](Task* task) { + task->startTask(); + return true; + }, task); +} + +void TaskController::pauseTask(Task* task, bool pause) const +{ + callControllerFunction("Starting a task", [pause](Task* task) { + task->pauseTask(pause); + return true; + }, task); +} + +void TaskController::refreshTask(Task* task) const +{ + callControllerFunction("Refreshing a task", [](Task* task) { + task->refreshTask(); + return true; + }, task); +} + +void TaskController::stopTask(Task* task) const +{ + callControllerFunction("Stopping a task", [](Task* task) { + task->stopTask(); + return true; + }, task); +} + +void TaskController::newTask(TaskType type) const +{ + callControllerFunction("Adding a new task", [this](TaskType type) { + QString newTaskName = StringUtils::generateUniqueItemName(_taskPool->tasks(), type + " task", &Task::getName); + + if (auto task = _taskPool->createTask(type, newTaskName, false)) + { + ConfigureTaskDialog dlg{task.get(), true}; + + if (dlg.exec() == QDialog::Accepted) + _taskPool->addTask(task); + } + + return true; + }, type); +} + +void TaskController::configureTask(Task* task) const +{ + callControllerFunction("Configuring a task", [](Task* task) { + if (!task->isRunning()) + { + ConfigureTaskDialog dlg{task, false}; + dlg.exec(); + } + + return true; + }, task); +} + +void TaskController::validateTaskName(QString name, const Task* task) const +{ + auto taskPool = task ? task->taskPool() : _taskPool; + + // Name must be non-empty and unique + if (!name.isEmpty()) + { + auto existingTask = taskPool->tasks().selectByName(name).get(); + + if (existingTask && existingTask != task) + throw TaskPoolException{taskPool, _EXCPT(QString{"A task with the name '%1' already exists"}.arg(name))}; + } + else + throw TaskPoolException{taskPool, _EXCPT("The task name may not be empty")}; +} + +void TaskController::taskCreated(const std::shared_ptr<Task>& task) const +{ + if (_taskPoolWidget) + _taskPoolWidget->addTask(task); +} + +void TaskController::taskRemoved(const std::shared_ptr<Task>& task) const +{ + if (_taskPoolWidget) + _taskPoolWidget->removeTask(task); +} + +void TaskController::updateTasks() +{ + callControllerFunction("Updating tasks", [this]() { + _taskPool->updateTasks(); + return true; + }); +} diff --git a/Grinder/controller/TaskController.h b/Grinder/controller/TaskController.h index af95bce0a00ad38b2715888799539fdabaadc184..f733ac88ed069211150490a01f39c3743c9c1ee2 100644 --- a/Grinder/controller/TaskController.h +++ b/Grinder/controller/TaskController.h @@ -1,60 +1,61 @@ -/****************************************************************************** - * File: TaskController.h - * Date: 01.11.2018 - *****************************************************************************/ - -#ifndef TASKCONTROLLER_H -#define TASKCONTROLLER_H - -#include "GenericController.h" -#include "task/TaskType.h" - -namespace grndr -{ - class TaskPool; - class Task; - class TaskPoolWidget; - - class TaskController : public GenericController - { - Q_OBJECT - - public: - TaskController(TaskPool* taskPool); - - public: - void assignUiComponents(TaskPoolWidget* taskPoolWidget); - - public: - std::shared_ptr<Task> createTask(TaskType type, QString name = "") const; - void removeTask(const Task* task) const; - void removeAllTasks(); - - void startTask(Task* task) const; - void pauseTask(Task* task, bool pause = true) const; - void refreshTask(Task* task) const; - void stopTask(Task* task) const; - - void newTask(TaskType type) const; - void configureTask(Task* task) const; - - public: - void validateTaskName(QString name, const Task* task = nullptr) const; - - private slots: - void taskCreated(const std::shared_ptr<Task>& task) const; - void taskRemoved(const std::shared_ptr<Task>& task) const; - - void updateTasks(); - - private: - TaskPool* _taskPool{nullptr}; - - TaskPoolWidget* _taskPoolWidget{nullptr}; - - private: - QTimer _updateTimer; - }; -} - -#endif +/****************************************************************************** + * File: TaskController.h + * Date: 01.11.2018 + *****************************************************************************/ + +#ifndef TASKCONTROLLER_H +#define TASKCONTROLLER_H + +#include "GenericController.h" +#include "task/TaskType.h" + +namespace grndr +{ + class TaskPool; + class Task; + class TaskPoolWidget; + + class TaskController : public GenericController + { + Q_OBJECT + + public: + TaskController(TaskPool* taskPool); + + public: + void assignUiComponents(TaskPoolWidget* taskPoolWidget); + + public: + std::shared_ptr<Task> createTask(TaskType type, QString name = "", bool addToPool = true) const; + void addTask(std::shared_ptr<Task>& task); + void removeTask(const Task* task) const; + void removeAllTasks(); + + void startTask(Task* task) const; + void pauseTask(Task* task, bool pause = true) const; + void refreshTask(Task* task) const; + void stopTask(Task* task) const; + + void newTask(TaskType type) const; + void configureTask(Task* task) const; + + public: + void validateTaskName(QString name, const Task* task = nullptr) const; + + private slots: + void taskCreated(const std::shared_ptr<Task>& task) const; + void taskRemoved(const std::shared_ptr<Task>& task) const; + + void updateTasks(); + + private: + TaskPool* _taskPool{nullptr}; + + TaskPoolWidget* _taskPoolWidget{nullptr}; + + private: + QTimer _updateTimer; + }; +} + +#endif diff --git a/Grinder/engine/ProcessorBase.cpp b/Grinder/engine/ProcessorBase.cpp index 863240cc99df1397adb15eb5834c2f7f833d46fc..4bac0ec3e07f7830fb6a34f3bf8546145cd0bb0b 100644 --- a/Grinder/engine/ProcessorBase.cpp +++ b/Grinder/engine/ProcessorBase.cpp @@ -56,10 +56,18 @@ DataDescriptor ProcessorBase::getBestDataDescriptor(const DataDescriptor& dataDe return DataDescriptor{}; } -DataBlob* ProcessorBase::portData(EngineExecutionContext& ctx, const Port* port, const Connection* connection, bool required, bool convert) const +DataBlob* ProcessorBase::portData(EngineExecutionContext& ctx, const Port* port, const Connection* connection, bool required, bool convert, bool persistentData) const { auto dataPort = resolveDataPort(port, connection, required); - DataBlob* data = dataPort ? ctx.data().get(dataPort) : nullptr; + DataBlob* data = nullptr; + + if (dataPort) + { + if (persistentData) + data = ctx.persistentData().get(dataPort); + else + data = ctx.data().get(dataPort); + } if (!data && required) throwProcessorException(QString{"No data could be retrieved for port '%1'"}.arg(port->getName())); @@ -100,7 +108,7 @@ cv::Mat ProcessorBase::getBypassedImage(const DataBlob* dataBlob, bool outputIsG void ProcessorBase::throwProcessorException(QString what) const { - throw ProcessorException{this, _EXCPT(QString{"%1: %2"}.arg(_block->getName()).arg(what))}; + throw ProcessorException{this, _EXCPT(what)}; } const Port* ProcessorBase::resolveDataPort(const Port* port, const Connection* connection, bool required) const diff --git a/Grinder/engine/ProcessorBase.h b/Grinder/engine/ProcessorBase.h index 526c0c29a8f92ae963a678a5fe717eda498680ad..dde06a91e06b051cca43c393065d0414d586c008 100644 --- a/Grinder/engine/ProcessorBase.h +++ b/Grinder/engine/ProcessorBase.h @@ -1,58 +1,61 @@ -/****************************************************************************** - * File: ProcessorBase.h - * Date: 21.2.2018 - *****************************************************************************/ - -#ifndef PROCESSORBASE_H -#define PROCESSORBASE_H - -#include "engine/EngineExecutionContext.h" -#include "common/properties/PropertyID.h" - -namespace grndr -{ - class Engine; - class Connection; - - class ProcessorBase - { - public: - ProcessorBase(const Block* block); - - public: - const Engine* engine() const { return _engine; } - const Block* block() const { return _block; } - - public: - virtual void execute(EngineExecutionContext& ctx); - - protected: - DataDescriptor getPortDataDescriptor(const Port* port, unsigned int index = 0) const; - DataDescriptor getBestDataDescriptor(const DataDescriptor& dataDesc, const DataDescriptors& targetDescriptors) const; - - DataBlob* portData(EngineExecutionContext& ctx, const Port* port, const Connection* connection = nullptr, bool required = true, bool convert = true) const; - template<typename PropType> - PropType* portProperty(const Port* port, PropertyID propertyID, const Connection* connection = nullptr, bool required = true) const; - - protected: - bool isBlockBypassed() const; - cv::Mat getBypassedImage(const DataBlob* dataBlob, bool outputIsGrayscale = false); - - protected: - void throwProcessorException(QString what) const; - - private: - const Port* resolveDataPort(const Port* port, const Connection* connection, bool required) const; - - protected: - const Engine* _engine{nullptr}; - const Block* _block{nullptr}; - - private: - mutable std::vector<std::shared_ptr<DataBlob>> _portDataCache; - }; -} - -#include "ProcessorBase.impl.h" - -#endif +/****************************************************************************** + * File: ProcessorBase.h + * Date: 21.2.2018 + *****************************************************************************/ + +#ifndef PROCESSORBASE_H +#define PROCESSORBASE_H + +#include "pipeline/BlockType.h" +#include "engine/EngineExecutionContext.h" +#include "common/properties/PropertyID.h" + +namespace grndr +{ + class Engine; + class Connection; + + class ProcessorBase + { + public: + ProcessorBase(const Block* block); + + public: + const Engine* engine() const { return _engine; } + const Block* block() const { return _block; } + + public: + virtual void execute(EngineExecutionContext& ctx); + + protected: + DataDescriptor getPortDataDescriptor(const Port* port, unsigned int index = 0) const; + DataDescriptor getBestDataDescriptor(const DataDescriptor& dataDesc, const DataDescriptors& targetDescriptors) const; + + DataBlob* portData(EngineExecutionContext& ctx, const Port* port, const Connection* connection = nullptr, bool required = true, bool convert = true, bool persistentData = false) const; + template<typename ValueType> + ValueType portData(EngineExecutionContext& ctx, const Port* port, QString dataName, bool required = true, BlockType blockType = BlockType::Undefined, bool persistentData = false) const; + template<typename PropType> + PropType* portProperty(const Port* port, PropertyID propertyID, const Connection* connection = nullptr, bool required = true) const; + + protected: + bool isBlockBypassed() const; + cv::Mat getBypassedImage(const DataBlob* dataBlob, bool outputIsGrayscale = false); + + protected: + void throwProcessorException(QString what) const; + + private: + const Port* resolveDataPort(const Port* port, const Connection* connection, bool required) const; + + protected: + const Engine* _engine{nullptr}; + const Block* _block{nullptr}; + + private: + mutable std::vector<std::shared_ptr<DataBlob>> _portDataCache; + }; +} + +#include "ProcessorBase.impl.h" + +#endif diff --git a/Grinder/engine/ProcessorBase.impl.h b/Grinder/engine/ProcessorBase.impl.h index e6d3724482519e15cfb019228e7dd43ecce46241..fa7f8ae434b225e698565503308c8901aaa8bb17 100644 --- a/Grinder/engine/ProcessorBase.impl.h +++ b/Grinder/engine/ProcessorBase.impl.h @@ -1,24 +1,51 @@ -/****************************************************************************** - * File: ProcessorBase.impl.h - * Date: 20.4.2018 - *****************************************************************************/ - -#include "Grinder.h" -#include "ProcessorBase.h" -#include "pipeline/Block.h" - -template<typename PropType> -PropType* ProcessorBase::portProperty(const Port* port, PropertyID propertyID, const Connection* connection, bool required) const -{ - if (auto dataPort = resolveDataPort(port, connection, required)) - { - auto property = _block->portProperty<PropType>(dataPort, propertyID); - - if (!property && required) - throwProcessorException(QString{"No property with ID '%2' could be retrieved for port '%1'"}.arg(port->getName()).arg(propertyID)); - - return property; - } - else - return nullptr; -} +/****************************************************************************** + * File: ProcessorBase.impl.h + * Date: 20.4.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "ProcessorBase.h" +#include "pipeline/Block.h" + +template<typename ValueType> +ValueType ProcessorBase::portData(EngineExecutionContext& ctx, const Port* port, QString dataName, bool required, BlockType blockType, bool persistentData) const +{ + auto connections = port->getConnections(port->getDirection(), blockType); + + if (connections.size() == 1) + { + auto sourcePort = connections[0]->sourcePort(); + + if (persistentData) + { + if (ctx.persistentData().contains(sourcePort, dataName)) + return ctx.persistentData().get<ValueType>(sourcePort, dataName); + } + else + { + if (ctx.data().contains(sourcePort, dataName)) + return ctx.data().get<ValueType>(sourcePort, dataName); + } + } + + if (required) + throwProcessorException(QString{"Data '%2' could be retrieved for port '%1'"}.arg(port->getName()).arg(dataName)); + + return {}; +} + +template<typename PropType> +PropType* ProcessorBase::portProperty(const Port* port, PropertyID propertyID, const Connection* connection, bool required) const +{ + if (auto dataPort = resolveDataPort(port, connection, required)) + { + auto property = _block->portProperty<PropType>(dataPort, propertyID); + + if (!property && required) + throwProcessorException(QString{"No property with ID '%2' could be retrieved for port '%1'"}.arg(port->getName()).arg(propertyID)); + + return property; + } + else + return nullptr; +} diff --git a/Grinder/engine/processors/GrabCutProcessor.cpp b/Grinder/engine/processors/GrabCutProcessor.cpp index 99d13cf9af80477279d81e2138ce07da7e271a3d..abde6ea39f4656d41d080112119016b8d4d0c3d9 100644 --- a/Grinder/engine/processors/GrabCutProcessor.cpp +++ b/Grinder/engine/processors/GrabCutProcessor.cpp @@ -37,21 +37,19 @@ void GrabCutProcessor::execute(EngineExecutionContext& ctx) { if (!isBlockBypassed()) { - cv::Mat mask; - cv::Mat backgroundModel; - cv::Mat foregroundModel; - // Get any preceding and succeeding GrabCut blocks auto predecessors = _block->predecessorPort()->getConnections(Port::Direction::In, BlockType::GrabCut); auto successors = _block->successorPort()->getConnections(Port::Direction::Out, BlockType::GrabCut); - if (predecessors.size() == 1) // Initialize the GrabCut algorithm using the data from the predecessor - { - auto predecessor = predecessors[0]->sourcePort(); + // Initialize the GrabCut algorithm using the data from the predecessor + cv::Mat backgroundModel = portData<cv::Mat>(ctx, _block->predecessorPort(), Data_Value_BGModel, false, BlockType::GrabCut); + cv::Mat foregroundModel = portData<cv::Mat>(ctx, _block->predecessorPort(), Data_Value_FGModel, false, BlockType::GrabCut); + cv::Mat mask; + + if (predecessors.size() == 1) + { mask = generateGrabCutMask(maskBlob->getMatrix(), false); - backgroundModel = ctx.data().get<cv::Mat>(predecessor, Data_Value_BGModel); - foregroundModel = ctx.data().get<cv::Mat>(predecessor, Data_Value_FGModel); } else // Initialize the GrabCut algorithm from scratch { diff --git a/Grinder/ml/MachineLearningMethod.cpp b/Grinder/ml/MachineLearningMethod.cpp deleted file mode 100644 index 0069233d320f4d8e187b0f5fe66e0f76095a0e83..0000000000000000000000000000000000000000 --- a/Grinder/ml/MachineLearningMethod.cpp +++ /dev/null @@ -1,7 +0,0 @@ -/****************************************************************************** - * File: MachineLearningMethod.cpp - * Date: 13.8.2019 - *****************************************************************************/ - -#include "Grinder.h" -#include "MachineLearningMethod.h" diff --git a/Grinder/ml/MachineLearningMethod.h b/Grinder/ml/MachineLearningMethod.h index 3edec2dc57f9942d547ce329b22421b9f879924f..0f00caa38be4a94993b55d224d2005616928c277 100644 --- a/Grinder/ml/MachineLearningMethod.h +++ b/Grinder/ml/MachineLearningMethod.h @@ -23,7 +23,10 @@ namespace grndr using spawner_type = SpawnerType; public: - virtual std::unique_ptr<MachineLearningTaskSpawnerBase> createTaskSpawner(InvocationMode mode) override; + using MachineLearningMethodBase::MachineLearningMethodBase; + + public: + virtual std::unique_ptr<MachineLearningTaskSpawnerBase> createTaskSpawner() const override; public: config_type& config() { return _config; } diff --git a/Grinder/ml/MachineLearningMethod.impl.h b/Grinder/ml/MachineLearningMethod.impl.h index 04ebec1e4d80665b8b4d1548cd31e23985bb4a19..f531b655579d0ce1ddbfa52f01805f4bd5e52fd0 100644 --- a/Grinder/ml/MachineLearningMethod.impl.h +++ b/Grinder/ml/MachineLearningMethod.impl.h @@ -8,7 +8,7 @@ #include "MachineLearningExceptions.h" template<typename ConfigType, typename SpawnerType> -std::unique_ptr<MachineLearningTaskSpawnerBase> MachineLearningMethod<ConfigType, SpawnerType>::createTaskSpawner(InvocationMode mode) +std::unique_ptr<MachineLearningTaskSpawnerBase> MachineLearningMethod<ConfigType, SpawnerType>::createTaskSpawner() const { - return std::make_unique<spawner_type>(mode, _config); + return std::make_unique<spawner_type>(_config); } diff --git a/Grinder/ml/MachineLearningMethodBase.cpp b/Grinder/ml/MachineLearningMethodBase.cpp index 43e3aa49f9e7188acf47a14da14fc3ef9d712b7d..125e832ac9a2d42674c6f5230e55b5d831273d3f 100644 --- a/Grinder/ml/MachineLearningMethodBase.cpp +++ b/Grinder/ml/MachineLearningMethodBase.cpp @@ -5,3 +5,9 @@ #include "Grinder.h" #include "MachineLearningMethodBase.h" + +MachineLearningMethodBase::MachineLearningMethodBase(QString methodName) : + _methodName{methodName} +{ + +} diff --git a/Grinder/ml/MachineLearningMethodBase.h b/Grinder/ml/MachineLearningMethodBase.h index 14d688985ba806b94aad2b2d6e4943a020ed5a10..6302def716a5808198eb795bd942b4c2a0bd2938 100644 --- a/Grinder/ml/MachineLearningMethodBase.h +++ b/Grinder/ml/MachineLearningMethodBase.h @@ -18,17 +18,19 @@ namespace grndr Q_OBJECT public: - enum class InvocationMode - { - Training, - Inference, - }; + MachineLearningMethodBase(QString methodName); public: - virtual std::unique_ptr<MachineLearningTaskSpawnerBase> createTaskSpawner(InvocationMode mode) = 0; + virtual std::unique_ptr<MachineLearningTaskSpawnerBase> createTaskSpawner() const = 0; public: virtual QStringList getAvailableStates() const = 0; + + public: + QString getMethodName() const { return _methodName; } + + protected: + QString _methodName{""}; }; } diff --git a/Grinder/ml/MachineLearningTaskSpawner.h b/Grinder/ml/MachineLearningTaskSpawner.h index 7a0d08920bb186634281a878f69a548bf3236014..b99048dba3c8523dfa585a744768b8ff34ec5b16 100644 --- a/Grinder/ml/MachineLearningTaskSpawner.h +++ b/Grinder/ml/MachineLearningTaskSpawner.h @@ -7,21 +7,42 @@ #define MACHINELEARNINGTASKSPAWNER_H #include "MachineLearningTaskSpawnerBase.h" +#include "task/TaskType.h" namespace grndr { class MachineLearningConfiguration; + class Task; - template<typename ConfigType> + template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> class MachineLearningTaskSpawner : public MachineLearningTaskSpawnerBase { static_assert(std::is_base_of<MachineLearningConfiguration, ConfigType>::value, "ConfigType must be derived from MachineLearningConfiguration"); + static_assert(std::is_base_of<Task, TrainingTaskType>::value, "TrainingTaskType must be derived from Task"); + static_assert(std::is_base_of<Task, InferenceTaskType>::value, "TrainingTaskType must be derived from Task"); public: using config_type = ConfigType; + using training_task_type = TrainingTaskType; + using inference_task_type = InferenceTaskType; public: - MachineLearningTaskSpawner(MachineLearningMethodBase::InvocationMode mode, const config_type& config); + MachineLearningTaskSpawner(const config_type& config); + + public: + virtual std::shared_ptr<Task> spawnTrainingTask(QString state, QString name) const override; + virtual std::shared_ptr<Task> spawnInferenceTask(QString state, QString name) const override; + + protected: + virtual void configureTrainingTask(training_task_type* task, QString state) const = 0; + virtual void configureInferenceTask(inference_task_type* task, QString state) const = 0; + + private: + template<typename TaskType> + std::shared_ptr<TaskType> createTask(QString name) const; + + std::shared_ptr<training_task_type> createTrainingTask(QString name) const { return createTask<training_task_type>(name); } + std::shared_ptr<inference_task_type> createInferenceTask(QString name) const { return createTask<inference_task_type>(name); } protected: const config_type& _config; diff --git a/Grinder/ml/MachineLearningTaskSpawner.impl.h b/Grinder/ml/MachineLearningTaskSpawner.impl.h index 0ac03a6c79f6ea1997b445d0a91f4a6c8df08240..8a87f14d4b17df564c3f20b7a8abd2dba2ccad96 100644 --- a/Grinder/ml/MachineLearningTaskSpawner.impl.h +++ b/Grinder/ml/MachineLearningTaskSpawner.impl.h @@ -5,10 +5,49 @@ #include "Grinder.h" #include "MachineLearningTaskSpawner.h" +#include "MachineLearningExceptions.h" +#include "core/GrinderApplication.h" -template<typename ConfigType> -MachineLearningTaskSpawner<ConfigType>::MachineLearningTaskSpawner(MachineLearningMethodBase::InvocationMode mode, const MachineLearningTaskSpawner::config_type& config) : MachineLearningTaskSpawnerBase(mode), +template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> +MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::MachineLearningTaskSpawner(const config_type& config) : _config{config} { } + +template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> +std::shared_ptr<Task> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::spawnTrainingTask(QString state, QString name) const +{ + // If any of the following functions fails, an exception will be thrown + auto task = createTrainingTask(name); + configureTrainingTask(task.get(), state); + + auto sharedTask = std::dynamic_pointer_cast<Task>(task); + grinder()->taskController().addTask(sharedTask); + return sharedTask; +} + +template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> +std::shared_ptr<Task> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::spawnInferenceTask(QString state, QString name) const +{ + // If any of the following functions fails, an exception will be thrown + auto task = createInferenceTask(name); + configureInferenceTask(task.get(), state); + + auto sharedTask = std::dynamic_pointer_cast<Task>(task); + grinder()->taskController().addTask(sharedTask); + return sharedTask; +} + +template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> +template<typename TaskType> +std::shared_ptr<TaskType> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::createTask(QString name) const +{ + if (auto task = grinder()->taskController().createTask(TaskType::type_value, name, false)) + { + if (auto typedTask = std::dynamic_pointer_cast<TaskType>(task)) + return typedTask; + } + + throw MachineLearningException{_EXCPT(QString{"Unable to create a task of type '%1'"}.arg(TaskType::type_value))}; +} diff --git a/Grinder/ml/MachineLearningTaskSpawnerBase.cpp b/Grinder/ml/MachineLearningTaskSpawnerBase.cpp index 15f258b8a2f1282693ceec441299f739256f5ad3..1ac396eb032131af17d95c05b45f18519fae18f9 100644 --- a/Grinder/ml/MachineLearningTaskSpawnerBase.cpp +++ b/Grinder/ml/MachineLearningTaskSpawnerBase.cpp @@ -5,9 +5,3 @@ #include "Grinder.h" #include "MachineLearningTaskSpawnerBase.h" - -MachineLearningTaskSpawnerBase::MachineLearningTaskSpawnerBase(MachineLearningMethodBase::InvocationMode mode) : - _invocationMode{mode} -{ - -} diff --git a/Grinder/ml/MachineLearningTaskSpawnerBase.h b/Grinder/ml/MachineLearningTaskSpawnerBase.h index ab0926e45f6e9aec84c713d328fe1aa416ec93dc..669a46a2cdb25b6056a0a4c07f9128111072b5f0 100644 --- a/Grinder/ml/MachineLearningTaskSpawnerBase.h +++ b/Grinder/ml/MachineLearningTaskSpawnerBase.h @@ -10,13 +10,13 @@ namespace grndr { + class Task; + class MachineLearningTaskSpawnerBase { public: - MachineLearningTaskSpawnerBase(MachineLearningMethodBase::InvocationMode mode); - - protected: - MachineLearningMethodBase::InvocationMode _invocationMode{MachineLearningMethodBase::InvocationMode::Training}; + virtual std::shared_ptr<Task> spawnTrainingTask(QString state, QString name) const = 0; + virtual std::shared_ptr<Task> spawnInferenceTask(QString state, QString name) const = 0; }; } diff --git a/Grinder/ml/barista/BaristaClassifierMethod.cpp b/Grinder/ml/barista/BaristaClassifierMethod.cpp index f73cf0370301e0322d76bbf9ff43dbb9306c6d79..d6c8552924cfd58807a86d3da383a510641aef6d 100644 --- a/Grinder/ml/barista/BaristaClassifierMethod.cpp +++ b/Grinder/ml/barista/BaristaClassifierMethod.cpp @@ -6,6 +6,11 @@ #include "Grinder.h" #include "BaristaClassifierMethod.h" +BaristaClassifierMethod::BaristaClassifierMethod() : MachineLearningMethod("Barista") +{ + +} + QStringList BaristaClassifierMethod::getAvailableStates() const { QString outputDir = _config.getOutputDirectory(); diff --git a/Grinder/ml/barista/BaristaClassifierMethod.h b/Grinder/ml/barista/BaristaClassifierMethod.h index 1da882e12932539839c258a412b9b9a1801129be..31dba6428993b83b91a6f9581c66d6851b2df16f 100644 --- a/Grinder/ml/barista/BaristaClassifierMethod.h +++ b/Grinder/ml/barista/BaristaClassifierMethod.h @@ -16,6 +16,9 @@ namespace grndr { Q_OBJECT + public: + BaristaClassifierMethod(); + public: virtual QStringList getAvailableStates() const override; }; diff --git a/Grinder/ml/barista/BaristaClassifierTaskSpawner.cpp b/Grinder/ml/barista/BaristaClassifierTaskSpawner.cpp index 3f43e2142a4352962ec647923f25fa14e6c81387..6e6d63fa2d7d09a4bc366dce9a866025f2ec29aa 100644 --- a/Grinder/ml/barista/BaristaClassifierTaskSpawner.cpp +++ b/Grinder/ml/barista/BaristaClassifierTaskSpawner.cpp @@ -5,3 +5,23 @@ #include "Grinder.h" #include "BaristaClassifierTaskSpawner.h" + +void BaristaClassifierTaskSpawner::configureTrainingTask(BaristaTrainingTask* task, QString state) const +{ + Q_UNUSED(state); + + configureTask(task); + + // Configure training specific settings + task->setMaxIterations(_config.getMaxIterations()); + task->setDisplayInterval(_config.getDisplayInterval()); + task->setSnapshotInterval(_config.getSnapshotInterval()); +} + +void BaristaClassifierTaskSpawner::configureInferenceTask(BaristaInferenceTask* task, QString state) const +{ + configureTask(task); + + // Configure inference specific settings + task->setNetworkStateFile(state); +} diff --git a/Grinder/ml/barista/BaristaClassifierTaskSpawner.h b/Grinder/ml/barista/BaristaClassifierTaskSpawner.h index d21073fc1dfdf008174690a58352d8ba7e65efa0..fee72d74fe6b417212283b52e43780f67858f3b6 100644 --- a/Grinder/ml/barista/BaristaClassifierTaskSpawner.h +++ b/Grinder/ml/barista/BaristaClassifierTaskSpawner.h @@ -8,14 +8,26 @@ #include "ml/MachineLearningTaskSpawner.h" #include "BaristaClassifierConfiguration.h" +#include "tasks/BaristaTrainingTask.h" +#include "tasks/BaristaInferenceTask.h" namespace grndr { - class BaristaClassifierTaskSpawner : public MachineLearningTaskSpawner<BaristaClassifierConfiguration> + class BaristaClassifierTaskSpawner : public MachineLearningTaskSpawner<BaristaClassifierConfiguration, BaristaTrainingTask, BaristaInferenceTask> { public: - using MachineLearningTaskSpawner<BaristaClassifierConfiguration>::MachineLearningTaskSpawner; + using MachineLearningTaskSpawner<BaristaClassifierConfiguration, BaristaTrainingTask, BaristaInferenceTask>::MachineLearningTaskSpawner; + + protected: + virtual void configureTrainingTask(BaristaTrainingTask* task, QString state) const override; + virtual void configureInferenceTask(BaristaInferenceTask* task, QString state) const override; + + private: + template<typename TaskType> + void configureTask(TaskType* task) const; }; } +#include "BaristaClassifierTaskSpawner.impl.h" + #endif diff --git a/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h b/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h new file mode 100644 index 0000000000000000000000000000000000000000..57907f282ba7ffdd0a4162207cbfca564847f4e5 --- /dev/null +++ b/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h @@ -0,0 +1,19 @@ +/****************************************************************************** + * File: BaristaClassifierTaskSpawner.impl.h + * Date: 14.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "BaristaClassifierTaskSpawner.h" + +template<typename TaskType> +void BaristaClassifierTaskSpawner::configureTask(TaskType* task) const +{ + // Apply general settings to the task + task->setBaristaPort(_config.getBaristaPort()); + task->setLibraryPath(_config.getLibraryPath()); + + task->setNetwork(_config.network()); + task->setOutputDirectory(_config.getOutputDirectory()); + task->setRemoteDirectory(_config.getRemoteDirectory()); +} diff --git a/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp b/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp index 948e374bd105c4e51300e7cc2e3529b02ee81cb7..ce428c5d9710f3ef619612e829550c44f2e8285d 100644 --- a/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp +++ b/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp @@ -9,14 +9,14 @@ const BlockCategory BaristaClassifierBlock::category_value = BlockCategory::MachineLearning; const BlockType BaristaClassifierBlock::type_value = BlockType::BaristaClassifier; -BaristaClassifierBlock::BaristaClassifierBlock(Pipeline* pipeline, QString name) : MachineLearningBlock(pipeline, type_value, category_value, name) +BaristaClassifierBlock::BaristaClassifierBlock(Pipeline* pipeline, QString name) : MachineLearningMethodBlock(pipeline, type_value, category_value, name) { } void BaristaClassifierBlock::createProperties() { - MachineLearningBlock::createProperties(); + MachineLearningMethodBlock::createProperties(); setPropertyGroup("General"); @@ -55,7 +55,7 @@ void BaristaClassifierBlock::createProperties() bool BaristaClassifierBlock::updateProperties(PropertyBase* updatedProp) { - bool updated = MachineLearningBlock::updateProperties(updatedProp); + bool updated = MachineLearningMethodBlock::updateProperties(updatedProp); // If the output directory has been modified, update the currently selected state if (updatedProp == _outputDirectory.get()) diff --git a/Grinder/ml/barista/blocks/BaristaClassifierBlock.h b/Grinder/ml/barista/blocks/BaristaClassifierBlock.h index 97e40b74bee1d7ff51ab5d54ff35a303a0910c84..50c43074221ed1122ac91391615347646d980e93 100644 --- a/Grinder/ml/barista/blocks/BaristaClassifierBlock.h +++ b/Grinder/ml/barista/blocks/BaristaClassifierBlock.h @@ -6,13 +6,13 @@ #ifndef BARISTACLASSIFIERBLOCK_H #define BARISTACLASSIFIERBLOCK_H -#include "ml/blocks/MachineLearningBlock.h" +#include "ml/blocks/MachineLearningMethodBlock.h" #include "ml/barista/BaristaClassifierMethod.h" #include "ml/barista/properties/BaristaNetworkProperty.h" namespace grndr { - class BaristaClassifierBlock : public MachineLearningBlock<BaristaClassifierMethod> + class BaristaClassifierBlock : public MachineLearningMethodBlock<BaristaClassifierMethod> { Q_OBJECT diff --git a/Grinder/ml/blocks/MachineLearningBlock.cpp b/Grinder/ml/blocks/MachineLearningBlock.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1841dda4ba04943e9e4450fb569cb1b38a30af28 --- /dev/null +++ b/Grinder/ml/blocks/MachineLearningBlock.cpp @@ -0,0 +1,19 @@ +/****************************************************************************** + * File: MachineLearningBlock.cpp + * Date: 13.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "MachineLearningBlock.h" + +void MachineLearningBlock::createPorts() +{ + DataDescriptors methodPortDataDescs = {DataDescriptor::customDescriptor("Machine learning method", DataType::MachineLearningMethod)}; + _methodPort = createPort(PortType::Method, Port::Direction::In, methodPortDataDescs, "Method"); + + DataDescriptors statePortDataDescs = {DataDescriptor::customDescriptor("Model state", DataType::MachineLearningState)}; + _statePort = createPort(PortType::State, Port::Direction::In, statePortDataDescs, "State"); + + DataDescriptors inPortDataDescs = {DataDescriptor::imageDescriptor(true, DataDescriptor::ValueType::Any), DataDescriptor::imageDescriptor(false, DataDescriptor::ValueType::Any)}; + _inPort = createPort(PortType::ImageIn, Port::Direction::In, inPortDataDescs, "In"); +} diff --git a/Grinder/ml/blocks/MachineLearningBlock.h b/Grinder/ml/blocks/MachineLearningBlock.h index 5fd548017222366f561db4880fa9c0a71fd473a9..e790f04f353b2866dc2737c62fa946cec983a086 100644 --- a/Grinder/ml/blocks/MachineLearningBlock.h +++ b/Grinder/ml/blocks/MachineLearningBlock.h @@ -7,64 +7,32 @@ #define MACHINELEARNINGBLOCK_H #include "pipeline/Block.h" -#include "ml/properties/MachineLearningStateProperty.h" namespace grndr { - class MachineLearningMethodBase; - - template<typename MethodType> class MachineLearningBlock : public Block { - static_assert(std::is_base_of<MachineLearningMethodBase, MethodType>::value, "MethodType must be derived from MachineLearningMethodBase"); - - public: - using method_type = MethodType; - - public: - MachineLearningBlock(Pipeline* pipeline, BlockType type, BlockCategory category, QString name = ""); + Q_OBJECT public: - virtual void initBlock() override; + using Block::Block; public: - method_type& method() { return _method; } - const method_type& method() const { return _method; } - - public: - virtual std::unique_ptr<ProcessorBase> createProcessor() const override; - - public: - auto state() { return dynamic_cast<MachineLearningStateProperty*>(_state.get()); } - auto state() const { return dynamic_cast<const MachineLearningStateProperty*>(_state.get()); } - Port* methodPort() { return _methodPort.get(); } const Port* methodPort() const { return _methodPort.get(); } Port* statePort() { return _statePort.get(); } const Port* statePort() const { return _statePort.get(); } + Port* inPort() { return _inPort.get(); } + const Port* inPort() const { return _inPort.get(); } protected: - virtual void createProperties() override; - virtual bool updateProperties(PropertyBase* updatedProp = nullptr) override; virtual void createPorts() override; - protected: - virtual void updateConfiguration() = 0; - - protected: - bool checkStateAvailability(); - - protected: - method_type _method; - - protected: - std::shared_ptr<PropertyBase> _state; - + private: std::shared_ptr<Port> _methodPort; std::shared_ptr<Port> _statePort; + std::shared_ptr<Port> _inPort; }; } -#include "MachineLearningBlock.impl.h" - #endif diff --git a/Grinder/ml/blocks/MachineLearningMethodBlock.h b/Grinder/ml/blocks/MachineLearningMethodBlock.h new file mode 100644 index 0000000000000000000000000000000000000000..9cf873b7bb3083c79bf1d86ad0593ed5399a8474 --- /dev/null +++ b/Grinder/ml/blocks/MachineLearningMethodBlock.h @@ -0,0 +1,70 @@ +/****************************************************************************** + * File: MachineLearningMethodBlock.h + * Date: 13.8.2019 + *****************************************************************************/ + +#ifndef MACHINELEARNINGMETHODBLOCK_H +#define MACHINELEARNINGMETHODBLOCK_H + +#include "pipeline/Block.h" +#include "ml/properties/MachineLearningStateProperty.h" + +namespace grndr +{ + class MachineLearningMethodBase; + + template<typename MethodType> + class MachineLearningMethodBlock : public Block + { + static_assert(std::is_base_of<MachineLearningMethodBase, MethodType>::value, "MethodType must be derived from MachineLearningMethodBase"); + + public: + using method_type = MethodType; + + public: + MachineLearningMethodBlock(Pipeline* pipeline, BlockType type, BlockCategory category, QString name = ""); + + public: + virtual void initBlock() override; + + public: + method_type& method() { return _method; } + const method_type& method() const { return _method; } + + public: + virtual std::unique_ptr<ProcessorBase> createProcessor() const override; + + public: + auto state() { return dynamic_cast<MachineLearningStateProperty*>(_state.get()); } + auto state() const { return dynamic_cast<const MachineLearningStateProperty*>(_state.get()); } + + Port* methodPort() { return _methodPort.get(); } + const Port* methodPort() const { return _methodPort.get(); } + Port* statePort() { return _statePort.get(); } + const Port* statePort() const { return _statePort.get(); } + + protected: + virtual void createProperties() override; + virtual bool updateProperties(PropertyBase* updatedProp = nullptr) override; + virtual void createPorts() override; + + protected: + virtual void updateConfiguration() = 0; + + protected: + bool checkStateAvailability(); + + protected: + method_type _method; + + protected: + std::shared_ptr<PropertyBase> _state; + + std::shared_ptr<Port> _methodPort; + std::shared_ptr<Port> _statePort; + }; +} + +#include "MachineLearningMethodBlock.impl.h" + +#endif diff --git a/Grinder/ml/blocks/MachineLearningBlock.impl.h b/Grinder/ml/blocks/MachineLearningMethodBlock.impl.h similarity index 62% rename from Grinder/ml/blocks/MachineLearningBlock.impl.h rename to Grinder/ml/blocks/MachineLearningMethodBlock.impl.h index c9e17bd7e62ba13b37051d6fc3c0743a132d149f..fef7a1847ed8161add7cce1283b3124323b95fbe 100644 --- a/Grinder/ml/blocks/MachineLearningBlock.impl.h +++ b/Grinder/ml/blocks/MachineLearningMethodBlock.impl.h @@ -1,20 +1,20 @@ /****************************************************************************** - * File: MachineLearningBlock.impl.h + * File: MachineLearningMethodBlock.impl.h * Date: 13.8.2019 *****************************************************************************/ #include "Grinder.h" -#include "MachineLearningBlock.h" -#include "ml/processors/MachineLearningProcessor.h" +#include "MachineLearningMethodBlock.h" +#include "ml/processors/MachineLearningMethodProcessor.h" template<typename MethodType> -MachineLearningBlock<MethodType>::MachineLearningBlock(Pipeline* pipeline, BlockType type, BlockCategory category, QString name) : Block(pipeline, type, category, name) +MachineLearningMethodBlock<MethodType>::MachineLearningMethodBlock(Pipeline* pipeline, BlockType type, BlockCategory category, QString name) : Block(pipeline, type, category, name) { _bypassPossible = false; } template<typename MethodType> -void MachineLearningBlock<MethodType>::initBlock() +void MachineLearningMethodBlock<MethodType>::initBlock() { Block::initBlock(); @@ -22,13 +22,13 @@ void MachineLearningBlock<MethodType>::initBlock() } template<typename MethodType> -std::unique_ptr<ProcessorBase> MachineLearningBlock<MethodType>::createProcessor() const +std::unique_ptr<ProcessorBase> MachineLearningMethodBlock<MethodType>::createProcessor() const { - return std::make_unique<MachineLearningProcessor<method_type>>(this); + return std::make_unique<MachineLearningMethodProcessor<method_type>>(this); } template<typename MethodType> -void MachineLearningBlock<MethodType>::createProperties() +void MachineLearningMethodBlock<MethodType>::createProperties() { Block::createProperties(); @@ -41,7 +41,7 @@ void MachineLearningBlock<MethodType>::createProperties() } template<typename MethodType> -bool MachineLearningBlock<MethodType>::updateProperties(PropertyBase* updatedProp) +bool MachineLearningMethodBlock<MethodType>::updateProperties(PropertyBase* updatedProp) { bool updated = Block::updateProperties(updatedProp); @@ -51,7 +51,7 @@ bool MachineLearningBlock<MethodType>::updateProperties(PropertyBase* updatedPro } template<typename MethodType> -void MachineLearningBlock<MethodType>::createPorts() +void MachineLearningMethodBlock<MethodType>::createPorts() { DataDescriptors methodPortDataDescs = {DataDescriptor::customDescriptor("Machine learning method", DataType::MachineLearningMethod)}; _methodPort = createPort(PortType::Method, Port::Direction::Out, methodPortDataDescs, "Method"); @@ -61,7 +61,7 @@ void MachineLearningBlock<MethodType>::createPorts() } template<typename MethodType> -bool MachineLearningBlock<MethodType>::checkStateAvailability() +bool MachineLearningMethodBlock<MethodType>::checkStateAvailability() { auto states = _method.getAvailableStates(); diff --git a/Grinder/ml/blocks/TrainingBlock.cpp b/Grinder/ml/blocks/TrainingBlock.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f3af15134df1dfb0bb097244aec823559794a3b --- /dev/null +++ b/Grinder/ml/blocks/TrainingBlock.cpp @@ -0,0 +1,31 @@ +/****************************************************************************** + * File: TrainingBlock.cpp + * Date: 13.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "TrainingBlock.h" +#include "ml/processors/TrainingProcessor.h" + +const BlockType TrainingBlock::type_value = BlockType::Training; +const BlockCategory TrainingBlock::category_value = BlockCategory::MachineLearning; + +TrainingBlock::TrainingBlock(Pipeline* pipeline, QString name) : MachineLearningBlock(pipeline, type_value, category_value, name) +{ + +} + +std::unique_ptr<ProcessorBase> TrainingBlock::createProcessor() const +{ + return std::make_unique<TrainingProcessor>(this); +} + +void TrainingBlock::createPorts() +{ + MachineLearningBlock::createPorts(); + + _tagsBitmapPort = createPort(PortType::ImageTagsBitmap, Port::Direction::In, {DataDescriptor::imageDescriptor(false, DataDescriptor::ValueType::UInt32)}, "Tags bitmap"); + + DataDescriptors trainedStatePortDataDescs = {DataDescriptor::customDescriptor("Trained state", DataType::MachineLearningState)}; + _trainedStatePort = createPort(PortType::TrainedState, Port::Direction::Out, trainedStatePortDataDescs, "State"); +} diff --git a/Grinder/ml/blocks/TrainingBlock.h b/Grinder/ml/blocks/TrainingBlock.h new file mode 100644 index 0000000000000000000000000000000000000000..05e136848fe194af9450e03a1c85432649a380a8 --- /dev/null +++ b/Grinder/ml/blocks/TrainingBlock.h @@ -0,0 +1,42 @@ +/****************************************************************************** + * File: TrainingBlock.h + * Date: 13.8.2019 + *****************************************************************************/ + +#ifndef TRAININGBLOCK_H +#define TRAININGBLOCK_H + +#include "MachineLearningBlock.h" + +namespace grndr +{ + class TrainingBlock : public MachineLearningBlock + { + Q_OBJECT + + public: + static const BlockType type_value; + static const BlockCategory category_value; + + public: + TrainingBlock(Pipeline* pipeline, QString name = ""); + + public: + virtual std::unique_ptr<ProcessorBase> createProcessor() const override; + + public: + Port* tagsBitmapPort() { return _tagsBitmapPort.get(); } + const Port* tagsBitmapPort() const { return _tagsBitmapPort.get(); } + Port* trainedStatePort() { return _trainedStatePort.get(); } + const Port* trainedStatePort() const { return _trainedStatePort.get(); } + + protected: + virtual void createPorts() override; + + private: + std::shared_ptr<Port> _tagsBitmapPort; + std::shared_ptr<Port> _trainedStatePort; + }; +} + +#endif diff --git a/Grinder/ml/processors/MachineLearningMethodProcessor.h b/Grinder/ml/processors/MachineLearningMethodProcessor.h new file mode 100644 index 0000000000000000000000000000000000000000..0a14e0a87c42b05c98a01f33b4d1fb101c21792d --- /dev/null +++ b/Grinder/ml/processors/MachineLearningMethodProcessor.h @@ -0,0 +1,38 @@ +/****************************************************************************** + * File: MachineLearningMethodProcessor.h + * Date: 13.8.2019 + *****************************************************************************/ + +#ifndef MACHINELEARNINGMETHODPROCESSOR_H +#define MACHINELEARNINGMETHODPROCESSOR_H + +#include "engine/Processor.h" +#include "ml/blocks/MachineLearningMethodBlock.h" + +namespace grndr +{ + class MachineLearningMethodBase; + + template<typename MethodType> + class MachineLearningMethodProcessor : public Processor<MachineLearningMethodBlock<MethodType>> + { + static_assert(std::is_base_of<MachineLearningMethodBase, MethodType>::value, "MethodType must be derived from MachineLearningMethodBase"); + + private: + static const char* Data_Value_Method; + static const char* Data_Value_State; + + public: + using method_type = MethodType; + + public: + MachineLearningMethodProcessor(const Block* block); + + public: + virtual void execute(EngineExecutionContext& ctx) override; + }; +} + +#include "MachineLearningMethodProcessor.impl.h" + +#endif diff --git a/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h b/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h new file mode 100644 index 0000000000000000000000000000000000000000..9c88e628246150dbc2789d957cdc8f1e35637ee5 --- /dev/null +++ b/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h @@ -0,0 +1,35 @@ +/****************************************************************************** + * File: MachineLearningMethodProcessor.impl.h + * Date: 13.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "MachineLearningMethodProcessor.h" +#include "ml/MachineLearningMethodBase.h" + +Q_DECLARE_METATYPE(const MachineLearningMethodBase*) + +template<typename MethodType> +const char* MachineLearningMethodProcessor<MethodType>::Data_Value_Method = "Method"; +template<typename MethodType> +const char* MachineLearningMethodProcessor<MethodType>::Data_Value_State = "State"; + +template<typename MethodType> +MachineLearningMethodProcessor<MethodType>::MachineLearningMethodProcessor(const Block* block) : Processor<MachineLearningMethodBlock<MethodType>>(block) +{ + +} + +template<typename MethodType> +void MachineLearningMethodProcessor<MethodType>::execute(EngineExecutionContext& ctx) +{ + if (ctx.isFirstImage()) + { + // Verify all configuration settings + this->_block->method().config().verifyConfiguration(); + } + + // Store the machine learning method and model state in the context data + ctx.data().set(this->_block->methodPort(), Data_Value_Method, static_cast<const MachineLearningMethodBase*>(&this->_block->method())); + ctx.data().set(this->_block->statePort(), Data_Value_State, this->_block->state()->getValue()); +} diff --git a/Grinder/ml/processors/MachineLearningProcessor.h b/Grinder/ml/processors/MachineLearningProcessor.h index 758aa28c741f6fd5c52dc22aab7c44a632f8e5b3..067c813cd2f33a94d0c5f3394d6a1b197e1134ca 100644 --- a/Grinder/ml/processors/MachineLearningProcessor.h +++ b/Grinder/ml/processors/MachineLearningProcessor.h @@ -7,29 +7,43 @@ #define MACHINELEARNINGPROCESSOR_H #include "engine/Processor.h" -#include "ml/blocks/MachineLearningBlock.h" namespace grndr { + class MachineLearningBlock; class MachineLearningMethodBase; + class Task; - template<typename MethodType> - class MachineLearningProcessor : public Processor<MachineLearningBlock<MethodType>> + template<typename BlockType> + class MachineLearningProcessor : public Processor<BlockType> { - static_assert(std::is_base_of<MachineLearningMethodBase, MethodType>::value, "MethodType must be derived from MachineLearningMethodBase"); + static_assert(std::is_base_of<MachineLearningBlock, BlockType>::value, "BlockType must be derived from MachineLearningBlock"); - public: + private: static const char* Data_Value_Method; static const char* Data_Value_State; - public: - using method_type = MethodType; + protected: + enum class SpawnType + { + Training, + Inference, + }; public: - MachineLearningProcessor(const Block* block); + using Processor<BlockType>::Processor; public: virtual void execute(EngineExecutionContext& ctx) override; + + protected: + virtual void execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) = 0; + + protected: + std::shared_ptr<Task> spawnTask(SpawnType type, const MachineLearningMethodBase* method, QString state) const; + + private: + QString getSpawnTypeName(SpawnType type) const; }; } diff --git a/Grinder/ml/processors/MachineLearningProcessor.impl.h b/Grinder/ml/processors/MachineLearningProcessor.impl.h index 6cd16e7b301ad2523f9227254570c2c5ed148d63..d130f029c613fcb02e81da3b472f10387965f3f5 100644 --- a/Grinder/ml/processors/MachineLearningProcessor.impl.h +++ b/Grinder/ml/processors/MachineLearningProcessor.impl.h @@ -6,33 +6,66 @@ #include "Grinder.h" #include "MachineLearningProcessor.h" #include "ml/MachineLearningMethodBase.h" +#include "ml/MachineLearningTaskSpawnerBase.h" + +template<typename BlockType> +const char* MachineLearningProcessor<BlockType>::Data_Value_Method = "Method"; +template<typename BlockType> +const char* MachineLearningProcessor<BlockType>::Data_Value_State = "State"; Q_DECLARE_METATYPE(const MachineLearningMethodBase*) -template<typename MethodType> -const char* MachineLearningProcessor<MethodType>::Data_Value_Method = "Method"; -template<typename MethodType> -const char* MachineLearningProcessor<MethodType>::Data_Value_State = "State"; +template<typename BlockType> +void MachineLearningProcessor<BlockType>::execute(EngineExecutionContext& ctx) +{ + if (!this->isBlockBypassed()) + { + // Get the used machine learning method and model state + const MachineLearningMethodBase* method = this->template portData<const MachineLearningMethodBase*>(ctx, this->_block->methodPort(), Data_Value_Method); + QString state = this->template portData<QString>(ctx, this->_block->statePort(), Data_Value_State, false); + + execute(ctx, method, state); + } +} -template<typename MethodType> -MachineLearningProcessor<MethodType>::MachineLearningProcessor(const Block* block) : Processor<MachineLearningBlock<MethodType>>(block) +template<typename BlockType> +std::shared_ptr<Task> MachineLearningProcessor<BlockType>::spawnTask(SpawnType type, const MachineLearningMethodBase* method, QString state) const { + if (auto spawner = method->createTaskSpawner()) + { + QString taskName = QString{"%1 %2 (%3)"}.arg(method->getMethodName()).arg(getSpawnTypeName(type)).arg(this->_block->getFormattedName()); + std::shared_ptr<Task> task; + + switch (type) + { + case SpawnType::Training: + task = spawner->spawnTrainingTask(state, taskName); + break; + + case SpawnType::Inference: + task = spawner->spawnInferenceTask(state, taskName); + break; + } + + return task; + } + else + this->throwProcessorException("Unable to create a task spawner"); + return nullptr; } -template<typename MethodType> -void MachineLearningProcessor<MethodType>::execute(EngineExecutionContext& ctx) +template<typename BlockType> +QString MachineLearningProcessor<BlockType>::getSpawnTypeName(SpawnType type) const { - // Only set persistent data on the first image - if (ctx.isFirstImage()) + switch (type) { - // Verify all configuration settings - this->_block->method().config().verifyConfiguration(); + case SpawnType::Training: + return "Training"; - // Store the machine learning method in the context persistent data - ctx.persistentData().set(this->_block->methodPort(), Data_Value_Method, static_cast<const MachineLearningMethodBase*>(&this->_block->method())); + case SpawnType::Inference: + return "Inference"; } - // Store the model state in the context data - ctx.data().set(this->_block->statePort(), Data_Value_State, this->_block->state()->getValue()); + return ""; } diff --git a/Grinder/ml/processors/TrainingProcessor.cpp b/Grinder/ml/processors/TrainingProcessor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dd685ed171e3ddacfc01fbde2a3a27465d2a4cf5 --- /dev/null +++ b/Grinder/ml/processors/TrainingProcessor.cpp @@ -0,0 +1,29 @@ +/****************************************************************************** + * File: TrainingProcessor.cpp + * Date: 13.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "TrainingProcessor.h" +#include "ml/MachineLearningTaskSpawnerBase.h" + +TrainingProcessor::TrainingProcessor(const Block* block) : MachineLearningProcessor(block) +{ + +} + +void TrainingProcessor::execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) +{ + // Training is only executed in batch mode + if (ctx.hasExecutionFlag(Engine::ExecutionFlag::Batch)) + { + // Spawn the training task when the last image is active + if (ctx.isLastImage()) + { + // TODO: Gather images, labels etc. before spawning + spawnTask(SpawnType::Training, method, state); + } + } + else + throwProcessorException("Training is only possible in batch mode; bypass the training block to avoid this warning"); +} diff --git a/Grinder/ml/processors/TrainingProcessor.h b/Grinder/ml/processors/TrainingProcessor.h new file mode 100644 index 0000000000000000000000000000000000000000..5dde33bf14840cc412e34b3a8c09190a0098c7fb --- /dev/null +++ b/Grinder/ml/processors/TrainingProcessor.h @@ -0,0 +1,24 @@ +/****************************************************************************** + * File: TrainingProcessor.h + * Date: 13.8.2019 + *****************************************************************************/ + +#ifndef TRAININGPROCESSOR_H +#define TRAININGPROCESSOR_H + +#include "MachineLearningProcessor.h" +#include "ml/blocks/TrainingBlock.h" + +namespace grndr +{ + class TrainingProcessor : public MachineLearningProcessor<TrainingBlock> + { + public: + TrainingProcessor(const Block* block); + + protected: + virtual void execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) override; + }; +} + +#endif diff --git a/Grinder/pipeline/BlockCatalog.cpp b/Grinder/pipeline/BlockCatalog.cpp index 917125f097e1808daadbd4ab185017f6e3882053..f5cc6e26a22b0d9d5084da98ae03de8f6c6463ba 100644 --- a/Grinder/pipeline/BlockCatalog.cpp +++ b/Grinder/pipeline/BlockCatalog.cpp @@ -33,6 +33,7 @@ #include "blocks/ResizeBlock.h" #include "blocks/SaveImageBlock.h" +#include "ml/blocks/TrainingBlock.h" #include "ml/barista/blocks/BaristaClassifierBlock.h" #define REGISTER_BLOCK_TYPE(cls, desc) registerBlockType(cls::type_value, cls::category_value, [](Pipeline* pipeline, QString name) { return std::make_unique<cls>(pipeline, name); }, desc) @@ -147,5 +148,6 @@ void BlockCatalog::registerStandardBlocks() REGISTER_BLOCK_TYPE(ResizeBlock, "Resizes its input."); REGISTER_BLOCK_TYPE(SaveImageBlock, "Saves its input to an image file."); + REGISTER_BLOCK_TYPE(TrainingBlock, "Performs training on labeled images using a machine learning method."); REGISTER_BLOCK_TYPE(BaristaClassifierBlock, "Provides the Barista machine learning classifier."); } diff --git a/Grinder/pipeline/BlockType.cpp b/Grinder/pipeline/BlockType.cpp index 1a757904c60a3c4b636f3ed052d6a333d9b552cb..0c357a55f20609c8ba38c109f33d643ce352873a 100644 --- a/Grinder/pipeline/BlockType.cpp +++ b/Grinder/pipeline/BlockType.cpp @@ -43,6 +43,8 @@ const char* BlockType::Edges = "Edges"; const char* BlockType::Watershed = "Watershed"; const char* BlockType::GrabCut = "GrabCut"; +const char* BlockType::Training = "Training"; +const char* BlockType::Inference = "Inference"; const char* BlockType::BaristaClassifier = "BaristaClassifier"; const char* BlockType::ImageTags = "ImageTags"; diff --git a/Grinder/pipeline/BlockType.h b/Grinder/pipeline/BlockType.h index 7d802661a0ac6575fed5f8ecda5c8172e564ec74..c1a423e8e45088ac3f2731101d3c621eb333985c 100644 --- a/Grinder/pipeline/BlockType.h +++ b/Grinder/pipeline/BlockType.h @@ -50,6 +50,8 @@ namespace grndr static const char* Watershed; static const char* GrabCut; + static const char* Training; + static const char* Inference; static const char* BaristaClassifier; static const char* ImageTags; diff --git a/Grinder/pipeline/PortType.cpp b/Grinder/pipeline/PortType.cpp index 11cbbd2fc4b447163f92cc7dea3e8e19ea7ea465..880a6a5e1a7b19b58b0fe406af44b3fd50eead7b 100644 --- a/Grinder/pipeline/PortType.cpp +++ b/Grinder/pipeline/PortType.cpp @@ -30,3 +30,4 @@ const char* PortType::Successor = "Successor"; const char* PortType::Method = "Method"; const char* PortType::State = "State"; +const char* PortType::TrainedState = "TrainedState"; diff --git a/Grinder/pipeline/PortType.h b/Grinder/pipeline/PortType.h index 72101b969c9d3af4d29cb4bfddecfe2192000772..07c92f3fcbadfca8fe6df7f0a3ffd60c596b44a4 100644 --- a/Grinder/pipeline/PortType.h +++ b/Grinder/pipeline/PortType.h @@ -37,6 +37,7 @@ namespace grndr static const char* Method; static const char* State; + static const char* TrainedState; public: using QString::QString;