diff --git a/Grinder/Grinder.pro b/Grinder/Grinder.pro index 8b2bdfd1e0c97413a83ad57479efaaeb4b82f006..4b70225606822ea613cb56944db87e194226f08a 100644 --- a/Grinder/Grinder.pro +++ b/Grinder/Grinder.pro @@ -456,7 +456,8 @@ SOURCES += \ ml/blocks/MachineLearningBlock.cpp \ ml/blocks/TrainingBlock.cpp \ ml/processors/TrainingProcessor.cpp \ - project/exporters/HDF5File.cpp + project/exporters/HDF5Export.cpp \ + ml/tasks/MachineLearningTask.cpp HEADERS += \ ui/mainwnd/GrinderWindow.h \ @@ -986,7 +987,9 @@ HEADERS += \ ml/processors/MachineLearningProcessor.h \ ml/processors/MachineLearningProcessor.impl.h \ ml/barista/BaristaClassifierTaskSpawner.impl.h \ - project/exporters/HDF5File.h + project/exporters/HDF5Export.h \ + pipeline/tasks/PipelineTask.h \ + ml/tasks/MachineLearningTask.h FORMS += \ ui/mainwnd/GrinderWindow.ui \ diff --git a/Grinder/Version.h b/Grinder/Version.h index 09743bf4a67b6c9682f3510bdcabca6fc1c43070..12a1b6ac393f5527d1b8a2fd1c0cc01e2a7051ab 100644 --- a/Grinder/Version.h +++ b/Grinder/Version.h @@ -10,14 +10,14 @@ #define GRNDR_INFO_TITLE "Grinder" #define GRNDR_INFO_COPYRIGHT "Copyright (c) WWU Muenster" -#define GRNDR_INFO_DATE "26.8.2019" +#define GRNDR_INFO_DATE "27.8.2019" #define GRNDR_INFO_COMPANY "WWU Muenster" #define GRNDR_INFO_WEBSITE "http://www.uni-muenster.de" #define GRNDR_VERSION_MAJOR 0 #define GRNDR_VERSION_MINOR 15 #define GRNDR_VERSION_REVISION 0 -#define GRNDR_VERSION_BUILD 376 +#define GRNDR_VERSION_BUILD 380 namespace grndr { diff --git a/Grinder/controller/TaskController.h b/Grinder/controller/TaskController.h index f733ac88ed069211150490a01f39c3743c9c1ee2..6d81dea9015742190fee8d3328bc9dc850b72283 100644 --- a/Grinder/controller/TaskController.h +++ b/Grinder/controller/TaskController.h @@ -7,12 +7,11 @@ #define TASKCONTROLLER_H #include "GenericController.h" -#include "task/TaskType.h" +#include "task/Task.h" namespace grndr { class TaskPool; - class Task; class TaskPoolWidget; class TaskController : public GenericController diff --git a/Grinder/engine/Engine.cpp b/Grinder/engine/Engine.cpp index 48667c557834e460cfc0b608d0688cec217f1d00..4198f26c44b5cb82a98253060a9f3132730e4c7a 100644 --- a/Grinder/engine/Engine.cpp +++ b/Grinder/engine/Engine.cpp @@ -34,7 +34,7 @@ cv::Mat Engine::executeLabelEx(Label* label, const Port* requestPort, std::vecto blockCount += level.size(); LongOperation opProcessImages{"Processing images", static_cast<unsigned int>(imageReferences.size())}; - EngineExecutionContext ctx{this, label, mode, flags}; + EngineExecutionContext ctx{this, label, imageReferences, mode, flags}; for (unsigned int i = 0; i < imageReferences.size(); ++i) { @@ -43,16 +43,7 @@ cv::Mat Engine::executeLabelEx(Label* label, const Port* requestPort, std::vecto opProcessImages.setStatusMessage(imageRef->getImageFileName()); LongOperation opProcessPipeline{"Processing pipeline blocks", blockCount}; - // Check if the image is the first and/or last one in the execution - EngineExecutionContext::ImagePositions imagePositions = EngineExecutionContext::ImagePosition::None; - - if (i == 0) - imagePositions.setFlag(EngineExecutionContext::ImagePosition::First); - - if (i == imageReferences.size() - 1) - imagePositions.setFlag(EngineExecutionContext::ImagePosition::Last); - - ctx.begin(imageRef, imagePositions); + ctx.begin(imageRef, i); for (unsigned int currentLevel = 0; currentLevel < blockHierarchy.size(); ++currentLevel) { diff --git a/Grinder/engine/EngineExecutionContext.cpp b/Grinder/engine/EngineExecutionContext.cpp index 210e4d5bfe404f14610a7395c0884a34e3bb059f..b0e0b17d05f2f58fef8f595c13d083eed7409c28 100644 --- a/Grinder/engine/EngineExecutionContext.cpp +++ b/Grinder/engine/EngineExecutionContext.cpp @@ -8,8 +8,8 @@ #include "pipeline/Block.h" #include "pipeline/Port.h" -EngineExecutionContext::EngineExecutionContext(const Engine* engine, Label* label, Engine::ExecutionMode mode, Engine::ExecutionFlags flags) : - _engine{engine}, _label{label}, _executionMode{mode}, _executionFlags{flags} +EngineExecutionContext::EngineExecutionContext(const Engine* engine, Label* label, const std::vector<const ImageReference*>& imageReferences, Engine::ExecutionMode mode, Engine::ExecutionFlags flags) : + _engine{engine}, _label{label}, _imageReferences{imageReferences}, _executionMode{mode}, _executionFlags{flags} { if (!engine) throw std::invalid_argument{_EXCPT("engine may not be null")}; @@ -18,19 +18,19 @@ EngineExecutionContext::EngineExecutionContext(const Engine* engine, Label* labe throw std::invalid_argument{_EXCPT("label may not be null")}; } -void EngineExecutionContext::begin(const ImageReference* activeImageReference, ImagePositions imagePositions) +void EngineExecutionContext::begin(const ImageReference* activeImageReference, unsigned int imageIndex) { if (!activeImageReference) throw std::invalid_argument{_EXCPT("activeImageReference may not be null")}; _activeImageReference = activeImageReference; - _activeImagePositions = imagePositions; + _activeImageIndex = imageIndex; } void EngineExecutionContext::end() { _activeImageReference = nullptr; - _activeImagePositions = ImagePosition::None; + _activeImageIndex = 0; // Clear previous execution data _executionData.clear(); diff --git a/Grinder/engine/EngineExecutionContext.h b/Grinder/engine/EngineExecutionContext.h index 0c7ad42fbb3e40fcc8e19ad3d42b74c8e8783e85..653895461a2f23b350f35f47df72cca3d4229e63 100644 --- a/Grinder/engine/EngineExecutionContext.h +++ b/Grinder/engine/EngineExecutionContext.h @@ -23,21 +23,10 @@ namespace grndr class EngineExecutionContext { public: - enum class ImagePosition : unsigned int - { - None = 0x00, - - First = 0x01, - Last = 0x02, - }; - - Q_DECLARE_FLAGS(ImagePositions, ImagePosition) - - public: - EngineExecutionContext(const Engine* engine, Label* label, Engine::ExecutionMode mode, Engine::ExecutionFlags flags); + EngineExecutionContext(const Engine* engine, Label* label, const std::vector<const ImageReference*>& imageReferences, Engine::ExecutionMode mode, Engine::ExecutionFlags flags); public: - void begin(const ImageReference* activeImageReference, ImagePositions imagePositions); + void begin(const ImageReference* activeImageReference, unsigned int imageIndex); void end(); void purge(const BlockHierarchy& blockHierarchy, unsigned int reachedLevel); @@ -47,9 +36,11 @@ namespace grndr Label* label() { return _label; } const Label* label() const { return _label; } + const std::vector<const ImageReference*>& imageReferences() const { return _imageReferences; } const ImageReference* activeImageReference() const { return _activeImageReference; } - bool isFirstImage() const { return _activeImagePositions.testFlag(ImagePosition::First); } - bool isLastImage() const { return _activeImagePositions.testFlag(ImagePosition::Last); } + unsigned int getActiveImageIndex() const { return _activeImageIndex; } + bool isFirstImage() const { return _activeImageReference && _activeImageIndex == 0; } + bool isLastImage() const { return _activeImageReference && _activeImageIndex == _imageReferences.size() - 1; } EngineExecutionData& data() { return _executionData; } const EngineExecutionData& data() const { return _executionData; } @@ -71,8 +62,9 @@ namespace grndr const Engine* _engine{nullptr}; Label* _label{nullptr}; + const std::vector<const ImageReference*>& _imageReferences; const ImageReference* _activeImageReference{nullptr}; - ImagePositions _activeImagePositions{ImagePosition::None}; + unsigned int _activeImageIndex{0}; EngineExecutionData _executionData; EngineExecutionData _persistentData; @@ -80,10 +72,7 @@ namespace grndr Engine::ExecutionMode _executionMode{Engine::ExecutionMode::Execute}; Engine::ExecutionFlags _executionFlags{Engine::ExecutionFlag::None}; bool _abortProcessing{false}; - }; } -Q_DECLARE_OPERATORS_FOR_FLAGS(grndr::EngineExecutionContext::ImagePositions) - #endif diff --git a/Grinder/engine/EngineExecutionData.cpp b/Grinder/engine/EngineExecutionData.cpp index 9f734cdab4a90665d3abb52e2544d2c0c76995dd..cfe3a1977fd10f886dc975474e64d4671c9c4432 100644 --- a/Grinder/engine/EngineExecutionData.cpp +++ b/Grinder/engine/EngineExecutionData.cpp @@ -5,41 +5,28 @@ #include "Grinder.h" #include "EngineExecutionData.h" -#include "pipeline/Port.h" -DataBlob* EngineExecutionData::get(const Port* port) +DataBlob* EngineExecutionData::get(const void* obj) { - if (!port->isOut()) - throw std::invalid_argument{_EXCPT("Data can only be retrieved for out-ports")}; - - if (contains(port)) - return &_data.at(port); + if (contains(obj)) + return &_data.at(obj); else return nullptr; } -void EngineExecutionData::set(const Port* port, const DataBlob& data) +void EngineExecutionData::set(const void* obj, const DataBlob& data) { - if (!port->isOut()) - throw std::invalid_argument{_EXCPT("Data can only be set for out-ports")}; - - _data.emplace(port, data); + _data.emplace(obj, data); } -void EngineExecutionData::set(const Port* port, DataBlob&& data) +void EngineExecutionData::set(const void* obj, DataBlob&& data) { - if (!port->isOut()) - throw std::invalid_argument{_EXCPT("Data can only be set for out-ports")}; - - _data.emplace(port, std::move(data)); + _data.emplace(obj, std::move(data)); } -QVariant EngineExecutionData::get(const Port* port, QString name) const +QVariant EngineExecutionData::get(const void* obj, QString name) const { - if (!port->isOut()) - throw std::invalid_argument{_EXCPT("Data can only be retrieved for out-ports")}; - - auto it = _values.find(port); + auto it = _values.find(obj); if (it != _values.cend()) { @@ -50,25 +37,19 @@ QVariant EngineExecutionData::get(const Port* port, QString name) const return QVariant{}; } -void EngineExecutionData::set(const Port* port, QString name, const QVariant& data) +void EngineExecutionData::set(const void* obj, QString name, const QVariant& data) { - if (!port->isOut()) - throw std::invalid_argument{_EXCPT("Data can only be set for out-ports")}; - - _values[port][name] = data; + _values[obj][name] = data; } -void EngineExecutionData::set(const Port* port, QString name, QVariant&& data) +void EngineExecutionData::set(const void* obj, QString name, QVariant&& data) { - if (!port->isOut()) - throw std::invalid_argument{_EXCPT("Data can only be set for out-ports")}; - - _values[port][name] = std::move(data); + _values[obj][name] = std::move(data); } -bool EngineExecutionData::contains(const Port* port, QString name) const +bool EngineExecutionData::contains(const void* obj, QString name) const { - auto it = _values.find(port); + auto it = _values.find(obj); if (it != _values.cend()) return it->second.find(name) != it->second.cend(); @@ -76,14 +57,14 @@ bool EngineExecutionData::contains(const Port* port, QString name) const return false; } -void EngineExecutionData::remove(const Port* port) +void EngineExecutionData::remove(const void* obj) { - _data.erase(port); + _data.erase(obj); } -void EngineExecutionData::remove(const Port* port, QString name) +void EngineExecutionData::remove(const void* obj, QString name) { - auto it = _values.find(port); + auto it = _values.find(obj); if (it != _values.cend()) it->second.erase(name); diff --git a/Grinder/engine/EngineExecutionData.h b/Grinder/engine/EngineExecutionData.h index 7d8dede5f0a1fbd32910ab42ef46a32b5ae43ef4..d255962d22e58aa7fa1e95442d477762b34ca490 100644 --- a/Grinder/engine/EngineExecutionData.h +++ b/Grinder/engine/EngineExecutionData.h @@ -11,36 +11,37 @@ namespace grndr { class Port; + class Block; class EngineExecutionData { private: - using PortValues = std::map<QString, QVariant>; + using Values = std::map<QString, QVariant>; public: - DataBlob* get(const Port* port); - void set(const Port* port, const DataBlob& data); - void set(const Port* port, DataBlob&& data); + DataBlob* get(const void* obj); + void set(const void* obj, const DataBlob& data); + void set(const void* obj, DataBlob&& data); - QVariant get(const Port* port, QString name) const; + QVariant get(const void* obj, QString name) const; template<typename ValueType> - ValueType get(const Port* port, QString name) const { return get(port, name).value<ValueType>(); } - void set(const Port* port, QString name, const QVariant& data); + ValueType get(const void* obj, QString name) const { return get(obj, name).value<ValueType>(); } + void set(const void* obj, QString name, const QVariant& data); template<typename ValueType> - void set(const Port* port, QString name, const ValueType& data) { set(port, name, QVariant::fromValue<ValueType>(data)); } - void set(const Port* port, QString name, QVariant&& data); + void set(const void* obj, QString name, const ValueType& data) { set(obj, name, QVariant::fromValue<ValueType>(data)); } + void set(const void* obj, QString name, QVariant&& data); - bool contains(const Port* port) const { return _data.find(port) != _data.end(); } - bool contains(const Port* port, QString name) const; + bool contains(const void* obj) const { return _data.find(obj) != _data.end(); } + bool contains(const void* obj, QString name) const; - void remove(const Port* port); - void remove(const Port* port, QString name); + void remove(const void* obj); + void remove(const void* obj, QString name); void clear(); private: - std::map<const Port*, DataBlob> _data; - std::map<const Port*, PortValues> _values; + std::map<const void*, DataBlob> _data; + std::map<const void*, Values> _values; }; } diff --git a/Grinder/engine/data/DataBlob.cpp b/Grinder/engine/data/DataBlob.cpp index 90c9316280f7d17b42fcfa659318dac78e16a73e..5711a23eaa105344faba274c2b46c022b777fb9f 100644 --- a/Grinder/engine/data/DataBlob.cpp +++ b/Grinder/engine/data/DataBlob.cpp @@ -1,124 +1,145 @@ -/****************************************************************************** - * File: DataBlob.cpp - * Date: 19.2.2018 - *****************************************************************************/ - -#include "Grinder.h" -#include "DataBlob.h" -#include "DataExceptions.h" -#include "cv/ColorConv.h" - -#include <opencv2/imgproc.hpp> - -DataBlob::DataBlob(const DataDescriptor& dataDesc, ColorSpace colorSpace) : - _dataDescriptor{dataDesc}, _colorSpace{colorSpace} -{ - if (dataDesc.isArbitrary()) - throw DataException{_EXCPT("dataDesc may not be arbitrary")}; -} - -void DataBlob::set(const cv::Mat& data) -{ - if (!data.empty()) - data.copyTo(_data); - else - clear(); - - updateDataDescriptor(); -} - -void DataBlob::set(cv::Mat&& data) -{ - _data = std::move(data); - updateDataDescriptor(); -} - -void DataBlob::convertTo(const DataDescriptor& dataDesc, bool normalize) -{ - if (!_dataDescriptor.canConvertTo(dataDesc) || dataDesc.isDynamic()) - throw DataException{_EXCPT("Invalid conversion")}; - - DataDescriptor dataDescNew = _dataDescriptor; - - int channels, depth; - dataDesc.getCVMatrixType(&channels, &depth); - - try { - // First, convert image colors count if necessary - if (dataDesc.getFieldType() != DataDescriptor::FieldType::Any && channels != _data.channels()) - { - // Color conversion can only be carried out on 8- or 16-bit unsigned or floating-point images - auto depth = _data.depth(); - - if (depth != CV_8U && depth != CV_16U && depth != CV_32F) - _data.convertTo(_data, CV_32F); - - // Check if a color <-> grayscale conversion can be done - if (_dataDescriptor.canConvertToColor(dataDesc)) - { - // Ensure that we're in RGB color space before going from grayscale to RGB - setColorSpace(ColorSpace::RGB); - - cv::cvtColor(_data, _data, cv::COLOR_GRAY2BGR); - - // Update the new data descriptor to match the new field type - dataDescNew = DataDescriptor{dataDescNew.getName(), dataDescNew.getStructureType(), DataDescriptor::FieldType::Color, dataDescNew.getValueType()}; - } - else if (_dataDescriptor.canConvertToGrayscale(dataDesc)) - { - // Ensure that we're in RGB color space before going to grayscale - setColorSpace(ColorSpace::RGB); - - cv::cvtColor(_data, _data, cv::COLOR_BGR2GRAY); - - // Update the new data descriptor to match the new field type - dataDescNew = DataDescriptor{dataDescNew.getName(), dataDescNew.getStructureType(), DataDescriptor::FieldType::Basic, dataDescNew.getValueType()}; - } - - // Convert the data back to the original depth - if (_data.depth() != depth) - _data.convertTo(_data, depth); - } - - // Next, convert the value type if necessary - if (dataDesc.getValueType() != DataDescriptor::ValueType::Any && depth != _data.depth()) - { - if (normalize && dataDesc.getValueType() < _dataDescriptor.getValueType()) // Normalize only if the new type is smaller than the current one - { - auto valueRange = dataDesc.getValueRange(); - cv::normalize(_data, _data, valueRange.first, valueRange.second, cv::NORM_MINMAX); - } - - _data.convertTo(_data, depth); - - // Update the new data descriptor to match the new value type - dataDescNew = DataDescriptor{dataDescNew.getName(), dataDescNew.getStructureType(), dataDescNew.getFieldType(), dataDesc.getValueType()}; - } - - // Set the new data descriptor to match the converted type - _dataDescriptor = dataDescNew; - } catch (std::exception& e) { - // Forward exceptions from OpenCV as a DataException - throw DataException{_EXCPT(e.what())}; - } -} - -void DataBlob::setColorSpace(ColorSpace colorSpace) -{ - if (colorSpace != _colorSpace) - { - try { - ColorConv::convertColorSpace(_data, _colorSpace, colorSpace); - } catch (std::exception& e) { - // Forward exceptions from OpenCV as a DataException - throw DataException{_EXCPT(e.what())}; - } - - _colorSpace = colorSpace; - } -} - -void DataBlob::updateDataDescriptor() -{ - _dataDescriptor.fromCVMatrixType(_data); -} +/****************************************************************************** + * File: DataBlob.cpp + * Date: 19.2.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "DataBlob.h" +#include "DataExceptions.h" +#include "cv/ColorConv.h" + +#include <opencv2/imgproc.hpp> + +DataBlob::DataBlob(const DataDescriptor& dataDesc, const std::vector<MetaData>& metaData, ColorSpace colorSpace) : + _dataDescriptor{dataDesc}, _colorSpace{colorSpace} +{ + if (dataDesc.isArbitrary()) + throw DataException{_EXCPT("dataDesc may not be arbitrary")}; + + mergeMetaData(metaData); +} + +void DataBlob::set(const cv::Mat& data) +{ + if (!data.empty()) + data.copyTo(_data); + else + clear(); + + updateDataDescriptor(); +} + +void DataBlob::set(cv::Mat&& data) +{ + _data = std::move(data); + updateDataDescriptor(); +} + +void DataBlob::convertTo(const DataDescriptor& dataDesc, bool normalize) +{ + if (!_dataDescriptor.canConvertTo(dataDesc) || dataDesc.isDynamic()) + throw DataException{_EXCPT("Invalid conversion")}; + + DataDescriptor dataDescNew = _dataDescriptor; + + int channels, depth; + dataDesc.getCVMatrixType(&channels, &depth); + + try { + // First, convert image colors count if necessary + if (dataDesc.getFieldType() != DataDescriptor::FieldType::Any && channels != _data.channels()) + { + // Color conversion can only be carried out on 8- or 16-bit unsigned or floating-point images + auto depth = _data.depth(); + + if (depth != CV_8U && depth != CV_16U && depth != CV_32F) + _data.convertTo(_data, CV_32F); + + // Check if a color <-> grayscale conversion can be done + if (_dataDescriptor.canConvertToColor(dataDesc)) + { + // Ensure that we're in RGB color space before going from grayscale to RGB + setColorSpace(ColorSpace::RGB); + + cv::cvtColor(_data, _data, cv::COLOR_GRAY2BGR); + + // Update the new data descriptor to match the new field type + dataDescNew = DataDescriptor{dataDescNew.getName(), dataDescNew.getStructureType(), DataDescriptor::FieldType::Color, dataDescNew.getValueType()}; + } + else if (_dataDescriptor.canConvertToGrayscale(dataDesc)) + { + // Ensure that we're in RGB color space before going to grayscale + setColorSpace(ColorSpace::RGB); + + cv::cvtColor(_data, _data, cv::COLOR_BGR2GRAY); + + // Update the new data descriptor to match the new field type + dataDescNew = DataDescriptor{dataDescNew.getName(), dataDescNew.getStructureType(), DataDescriptor::FieldType::Basic, dataDescNew.getValueType()}; + } + + // Convert the data back to the original depth + if (_data.depth() != depth) + _data.convertTo(_data, depth); + } + + // Next, convert the value type if necessary + if (dataDesc.getValueType() != DataDescriptor::ValueType::Any && depth != _data.depth()) + { + if (normalize && dataDesc.getValueType() < _dataDescriptor.getValueType()) // Normalize only if the new type is smaller than the current one + { + auto valueRange = dataDesc.getValueRange(); + cv::normalize(_data, _data, valueRange.first, valueRange.second, cv::NORM_MINMAX); + } + + _data.convertTo(_data, depth); + + // Update the new data descriptor to match the new value type + dataDescNew = DataDescriptor{dataDescNew.getName(), dataDescNew.getStructureType(), dataDescNew.getFieldType(), dataDesc.getValueType()}; + } + + // Set the new data descriptor to match the converted type + _dataDescriptor = dataDescNew; + } catch (std::exception& e) { + // Forward exceptions from OpenCV as a DataException + throw DataException{_EXCPT(e.what())}; + } +} + +void DataBlob::setColorSpace(ColorSpace colorSpace) +{ + if (colorSpace != _colorSpace) + { + try { + ColorConv::convertColorSpace(_data, _colorSpace, colorSpace); + } catch (std::exception& e) { + // Forward exceptions from OpenCV as a DataException + throw DataException{_EXCPT(e.what())}; + } + + _colorSpace = colorSpace; + } +} + +void DataBlob::mergeMetaData(const std::vector<MetaData>& metaData) +{ + if (!metaData.empty()) + { + for (auto data : metaData) + { + if (_metaData.isEmpty()) + { + _metaData = data; + } + else + { + for (auto key : data.keys()) + _metaData[key] = data[key]; + } + } + } +} + +void DataBlob::updateDataDescriptor() +{ + _dataDescriptor.fromCVMatrixType(_data); +} diff --git a/Grinder/engine/data/DataBlob.h b/Grinder/engine/data/DataBlob.h index 3cf7bdd3d7349a0221d30ad7c30ef7f5bff67763..853d900fd95fa6c2eb812968e4d00e00d3892c82 100644 --- a/Grinder/engine/data/DataBlob.h +++ b/Grinder/engine/data/DataBlob.h @@ -1,85 +1,97 @@ -/****************************************************************************** - * File: DataBlob.h - * Date: 19.2.2018 - *****************************************************************************/ - -#ifndef DATABLOB_H -#define DATABLOB_H - -#include "DataDescriptor.h" -#include "cv/CVTypes.h" - -#include <opencv2/core.hpp> - -namespace grndr -{ - class DataBlob - { - public: - DataBlob(const DataDescriptor& dataDesc, ColorSpace colorSpace = ColorSpace::RGB); - template<typename DataType> - DataBlob(const DataDescriptor& dataDesc, const DataType& data, ColorSpace colorSpace = ColorSpace::RGB); - template<typename DataType> - DataBlob(const DataDescriptor& dataDesc, DataType&& data, ColorSpace colorSpace = ColorSpace::RGB); - DataBlob(const DataBlob& blob) = default; - DataBlob(DataBlob&& blob) = default; - - DataBlob& operator =(const DataBlob& blob) = default; - DataBlob& operator =(DataBlob&& blob) = default; - - DataBlob& operator =(const cv::Mat& data) { set(data); return *this; } - template<typename DataType> - DataBlob& operator =(const std::vector<DataType>& data) { set(data); return *this; } - template<typename DataType> - DataBlob& operator =(const DataType& data) { set(data); return *this; } - - operator cv::Mat() const { return getMatrix(); } - template<typename DataType> - operator std::vector<DataType>() const { return getVector<DataType>(); } - template<typename DataType> - operator DataType() const { return getScalar<DataType>(); } - - public: - void set(const cv::Mat& data); - void set(cv::Mat&& data); - template<typename DataType> - void set(const std::vector<DataType>& data); - template<typename DataType> - void set(const DataType& data); - - cv::Mat getMatrix() const { return _data; } - template<typename DataType> - std::vector<DataType> getVector() const; - template<typename DataType> - DataType getScalar() const; - - void convertTo(const DataDescriptor& dataDesc, bool normalize = false); - template<typename TargetType> - void convertTo(const DataDescriptor& dataDesc); - - void clear() { _data.release(); } - - public: - const DataDescriptor& dataDescriptor() { return _dataDescriptor; } - ColorSpace getColorSpace() const { return _colorSpace; } - void setColorSpace(ColorSpace colorSpace); - - cv::Mat& data() { return _data; } - const cv::Mat& data() const { return _data; } - - bool empty() const { return _data.empty(); } - - private: - void updateDataDescriptor(); - - private: - DataDescriptor _dataDescriptor; - ColorSpace _colorSpace{ColorSpace::RGB}; - - cv::Mat _data; - }; -} - -#include "DataBlob.impl.h" - -#endif +/****************************************************************************** + * File: DataBlob.h + * Date: 19.2.2018 + *****************************************************************************/ + +#ifndef DATABLOB_H +#define DATABLOB_H + +#include "DataDescriptor.h" +#include "cv/CVTypes.h" + +#include <opencv2/core.hpp> + +namespace grndr +{ + class DataBlob + { + public: + using MetaData = QVariantMap; + + public: + DataBlob(const DataDescriptor& dataDesc, const std::vector<MetaData>& metaData = {}, ColorSpace colorSpace = ColorSpace::RGB); + template<typename DataType> + DataBlob(const DataDescriptor& dataDesc, const DataType& data, const std::vector<MetaData>& metaData = {}, ColorSpace colorSpace = ColorSpace::RGB); + template<typename DataType> + DataBlob(const DataDescriptor& dataDesc, DataType&& data, const std::vector<MetaData>& metaData = {}, ColorSpace colorSpace = ColorSpace::RGB); + DataBlob(const DataBlob& blob) = default; + DataBlob(DataBlob&& blob) = default; + + DataBlob& operator =(const DataBlob& blob) = default; + DataBlob& operator =(DataBlob&& blob) = default; + + DataBlob& operator =(const cv::Mat& data) { set(data); return *this; } + template<typename DataType> + DataBlob& operator =(const std::vector<DataType>& data) { set(data); return *this; } + template<typename DataType> + DataBlob& operator =(const DataType& data) { set(data); return *this; } + + operator cv::Mat() const { return getMatrix(); } + template<typename DataType> + operator std::vector<DataType>() const { return getVector<DataType>(); } + template<typename DataType> + operator DataType() const { return getScalar<DataType>(); } + + public: + void set(const cv::Mat& data); + void set(cv::Mat&& data); + template<typename DataType> + void set(const std::vector<DataType>& data); + template<typename DataType> + void set(const DataType& data); + + cv::Mat getMatrix() const { return _data; } + template<typename DataType> + std::vector<DataType> getVector() const; + template<typename DataType> + DataType getScalar() const; + + void convertTo(const DataDescriptor& dataDesc, bool normalize = false); + template<typename TargetType> + void convertTo(const DataDescriptor& dataDesc); + + void clear() { _data.release(); } + + public: + const DataDescriptor& dataDescriptor() { return _dataDescriptor; } + ColorSpace getColorSpace() const { return _colorSpace; } + void setColorSpace(ColorSpace colorSpace); + + cv::Mat& data() { return _data; } + const cv::Mat& data() const { return _data; } + + bool empty() const { return _data.empty(); } + + public: + const MetaData& metaData() const {return _metaData; } + MetaData& metaData() {return _metaData; } + + void mergeMetaData(const std::vector<MetaData>& metaData); + void clearMetaData() { _metaData.clear(); } + + private: + void updateDataDescriptor(); + + private: + DataDescriptor _dataDescriptor; + ColorSpace _colorSpace{ColorSpace::RGB}; + + cv::Mat _data; + + MetaData _metaData; + }; +} + +#include "DataBlob.impl.h" + +#endif diff --git a/Grinder/engine/data/DataBlob.impl.h b/Grinder/engine/data/DataBlob.impl.h index 9f88026f287017d390229c167ff767bfa7918169..b0ebe49801560e2d43684a226bb93bd758fe0bb7 100644 --- a/Grinder/engine/data/DataBlob.impl.h +++ b/Grinder/engine/data/DataBlob.impl.h @@ -10,15 +10,17 @@ #include <cstring> template<typename DataType> -DataBlob::DataBlob(const DataDescriptor& dataDesc, const DataType& data, ColorSpace colorSpace) : DataBlob(dataDesc, colorSpace) +DataBlob::DataBlob(const DataDescriptor& dataDesc, const DataType& data, const std::vector<MetaData>& metaData, ColorSpace colorSpace) : DataBlob(dataDesc, metaData, colorSpace) { set(data); + mergeMetaData(metaData); } template<typename DataType> -DataBlob::DataBlob(const DataDescriptor& dataDesc, DataType&& data, ColorSpace colorSpace) : DataBlob(dataDesc, colorSpace) +DataBlob::DataBlob(const DataDescriptor& dataDesc, DataType&& data, const std::vector<MetaData>& metaData, ColorSpace colorSpace) : DataBlob(dataDesc, metaData, colorSpace) { set(std::move(data)); + mergeMetaData(metaData); } template<typename DataType> diff --git a/Grinder/engine/processors/AdaptiveThresholdProcessor.cpp b/Grinder/engine/processors/AdaptiveThresholdProcessor.cpp index a73c677f8fd5f2bbdbf5ae34532d9ddcb855821d..e1e765bd3fe654d886028cedc5b9e06a5eec6bb5 100644 --- a/Grinder/engine/processors/AdaptiveThresholdProcessor.cpp +++ b/Grinder/engine/processors/AdaptiveThresholdProcessor.cpp @@ -28,6 +28,6 @@ void AdaptiveThresholdProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob, true); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/AlphaBlendingProcessor.cpp b/Grinder/engine/processors/AlphaBlendingProcessor.cpp index 0aa1b20da22fe8c85dca0c3bfc33fd7944a5ed01..948f4086b054647bf710f8fe0c23e97f1b6d9573 100644 --- a/Grinder/engine/processors/AlphaBlendingProcessor.cpp +++ b/Grinder/engine/processors/AlphaBlendingProcessor.cpp @@ -31,6 +31,6 @@ void AlphaBlendingProcessor::execute(EngineExecutionContext& ctx) auto alpha = _block->alpha()->getRelativeValue(); cv::addWeighted(dataBlob1->getMatrix(), alpha, dataBlob2->getMatrix(), 1.0 - alpha, 0.0, processedImage); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob1->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob1->metaData(), dataBlob2->metaData()}, dataBlob1->getColorSpace()}); } } diff --git a/Grinder/engine/processors/BinaryThresholdProcessor.cpp b/Grinder/engine/processors/BinaryThresholdProcessor.cpp index 208c14b0b7d96277742d00d1fbf3849664307e1e..22a4089df2e15a45d03f1351aced1e14df72530d 100644 --- a/Grinder/engine/processors/BinaryThresholdProcessor.cpp +++ b/Grinder/engine/processors/BinaryThresholdProcessor.cpp @@ -35,6 +35,6 @@ void BinaryThresholdProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/BlurProcessor.cpp b/Grinder/engine/processors/BlurProcessor.cpp index 10e621aecbe52ece326c6fd9c0335ecd4ba2292c..1bf58e3c87eb33f7189d3580cdd0e0e0a104b867 100644 --- a/Grinder/engine/processors/BlurProcessor.cpp +++ b/Grinder/engine/processors/BlurProcessor.cpp @@ -54,6 +54,6 @@ void BlurProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/CanvasProcessor.cpp b/Grinder/engine/processors/CanvasProcessor.cpp index 09f63be1549c1a527b714269428a004ad38095b4..9422d1b053304a3ffa852fc06b13a3e68360da00 100644 --- a/Grinder/engine/processors/CanvasProcessor.cpp +++ b/Grinder/engine/processors/CanvasProcessor.cpp @@ -5,8 +5,9 @@ #include "Grinder.h" #include "CanvasProcessor.h" -#include "project/Label.h" #include "core/GrinderApplication.h" +#include "project/Label.h" +#include "image/ImageTags.h" #include <opencv2/imgproc.hpp> #include <opencv2/highgui.hpp> @@ -48,8 +49,8 @@ void CanvasProcessor::execute(EngineExecutionContext& ctx) ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), renderedImage}); // Create the (merged) tags bitmap - ImageTagsBitmap tagsBitmap = imageBuild->renderImageTagsBitmap(); - ctx.data().set(_block->tagsBitmapPort(), DataBlob{getPortDataDescriptor(_block->tagsBitmapPort()), tagsBitmap.imageTagsBitmap().data()}); + ImageTagsBitmap tagsBitmap = imageBuild->renderImageTagsBitmap(false); + ctx.data().set(_block->tagsBitmapPort(), DataBlob{getPortDataDescriptor(_block->tagsBitmapPort()), tagsBitmap.imageTagsBitmap().data(), {dataBlob->metaData()}}); if (ctx.getExecutionMode() == Engine::ExecutionMode::View) grinder()->imageEditorManager().showEditor(_block, imageBuild); diff --git a/Grinder/engine/processors/ContoursProcessor.cpp b/Grinder/engine/processors/ContoursProcessor.cpp index bfce2bc66d74e9e3c4ec26c7f57d8ee3afaccf03..17df4e9e6352f1908dbd3466b312cb89112b33d9 100644 --- a/Grinder/engine/processors/ContoursProcessor.cpp +++ b/Grinder/engine/processors/ContoursProcessor.cpp @@ -35,6 +35,6 @@ void ContoursProcessor::execute(EngineExecutionContext& ctx) cv::drawContours(processedImage, contours, i, cv::Scalar::all(std::lround((i + 1) * colorStep)), thickness); } - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), maskBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {maskBlob->metaData()}, maskBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/ConvertToColorProcessor.cpp b/Grinder/engine/processors/ConvertToColorProcessor.cpp index d67e4bcc1f9fba934c186ee3adda774e0c98dae4..bb8b5afffd0f0cb25692bb0b21784ad4008fcb80 100644 --- a/Grinder/engine/processors/ConvertToColorProcessor.cpp +++ b/Grinder/engine/processors/ConvertToColorProcessor.cpp @@ -26,7 +26,7 @@ void ConvertToColorProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - DataBlob dataBlobNew{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}; + DataBlob dataBlobNew{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}; dataBlobNew.setColorSpace(*_block->colorSpace()); ctx.data().set(_block->outPort(), std::move(dataBlobNew)); diff --git a/Grinder/engine/processors/ConvertToGrayscaleProcessor.cpp b/Grinder/engine/processors/ConvertToGrayscaleProcessor.cpp index 85aec5e47124ef1d8dd0102d04d0fc81566307ce..9d1acbb08182ffc3f0f93a887ccf16d05c683341 100644 --- a/Grinder/engine/processors/ConvertToGrayscaleProcessor.cpp +++ b/Grinder/engine/processors/ConvertToGrayscaleProcessor.cpp @@ -26,6 +26,6 @@ void ConvertToGrayscaleProcessor::execute(EngineExecutionContext& ctx) else processedImage = dataBlob->getMatrix(); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/DilateProcessor.cpp b/Grinder/engine/processors/DilateProcessor.cpp index 2f961e82d73c99b62592832c0017e63d28892e06..38db1202fb8d95d700e7736cbf8a08407b7b0683 100644 --- a/Grinder/engine/processors/DilateProcessor.cpp +++ b/Grinder/engine/processors/DilateProcessor.cpp @@ -31,6 +31,6 @@ void DilateProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/DistanceTransformProcessor.cpp b/Grinder/engine/processors/DistanceTransformProcessor.cpp index 1873aadd55aff0366019afed865496755e718d1f..1c1f6806e9759adedbb47b72817d40b6b86d034f 100644 --- a/Grinder/engine/processors/DistanceTransformProcessor.cpp +++ b/Grinder/engine/processors/DistanceTransformProcessor.cpp @@ -32,6 +32,6 @@ void DistanceTransformProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob, true); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/EdgesProcessor.cpp b/Grinder/engine/processors/EdgesProcessor.cpp index f1b3934fb338a2998f81bee5808d97efe2d83f79..b4e07e83f9481280053c439ef7bbe3d0f520aa55 100644 --- a/Grinder/engine/processors/EdgesProcessor.cpp +++ b/Grinder/engine/processors/EdgesProcessor.cpp @@ -27,6 +27,6 @@ void EdgesProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob, true); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/EnhanceContrastProcessor.cpp b/Grinder/engine/processors/EnhanceContrastProcessor.cpp index 0d7e1579765a8d4e059a8e626428859d067fe236..424b9af8e3ea3d03e51350c8cd471fde463760f6 100644 --- a/Grinder/engine/processors/EnhanceContrastProcessor.cpp +++ b/Grinder/engine/processors/EnhanceContrastProcessor.cpp @@ -58,6 +58,6 @@ void EnhanceContrastProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/ErodeProcessor.cpp b/Grinder/engine/processors/ErodeProcessor.cpp index fecfc76e4cccdcc8b02d20105eac5023f22de6f9..df53307ca5e2fc51d367509f204f9d45d33de49f 100644 --- a/Grinder/engine/processors/ErodeProcessor.cpp +++ b/Grinder/engine/processors/ErodeProcessor.cpp @@ -32,6 +32,6 @@ void ErodeProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/GrabCutProcessor.cpp b/Grinder/engine/processors/GrabCutProcessor.cpp index abde6ea39f4656d41d080112119016b8d4d0c3d9..28a2f17ea1adf2518c4bb264cd8ba3a9b5cf3819 100644 --- a/Grinder/engine/processors/GrabCutProcessor.cpp +++ b/Grinder/engine/processors/GrabCutProcessor.cpp @@ -59,7 +59,7 @@ void GrabCutProcessor::execute(EngineExecutionContext& ctx) } cv::grabCut(dataBlob->getMatrix(), mask, cv::Rect{}, backgroundModel, foregroundModel, *_block->iterations(), cv::GC_INIT_WITH_MASK); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), generateOutputMask(mask, successors.empty()), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), generateOutputMask(mask, successors.empty()), {dataBlob->metaData()}, dataBlob->getColorSpace()}); // Store the fore- and background models if it is used later by another GrabCut instance ctx.data().set(_block->successorPort(), Data_Value_BGModel, backgroundModel); @@ -68,7 +68,7 @@ void GrabCutProcessor::execute(EngineExecutionContext& ctx) else { cv::Mat processedImage = getBypassedImage(dataBlob, true); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } } diff --git a/Grinder/engine/processors/MergeChannelsProcessor.cpp b/Grinder/engine/processors/MergeChannelsProcessor.cpp index f8d12829c2983c607e5fc513163c3f15e240992d..b5710c77da6f10a82fb56ab2c3c7fd3f0ccc1a39 100644 --- a/Grinder/engine/processors/MergeChannelsProcessor.cpp +++ b/Grinder/engine/processors/MergeChannelsProcessor.cpp @@ -36,6 +36,6 @@ void MergeChannelsProcessor::execute(EngineExecutionContext& ctx) cv::Mat processedImage; cv::merge(imageChannels, 3, processedImage); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), channel1Blob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {channel1Blob->metaData(), channel2Blob->metaData(), channel3Blob->metaData()}, channel1Blob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/NormalizeProcessor.cpp b/Grinder/engine/processors/NormalizeProcessor.cpp index 1eff2a1c81914f5bea5c073f2b2b7dc756289528..31c4cdc2f61501ec5242a37be4499817fd3afc21 100644 --- a/Grinder/engine/processors/NormalizeProcessor.cpp +++ b/Grinder/engine/processors/NormalizeProcessor.cpp @@ -26,6 +26,6 @@ void NormalizeProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/ReplaceColorProcessor.cpp b/Grinder/engine/processors/ReplaceColorProcessor.cpp index fce30a1108872aac7cedf1f96177cb4c93f6712d..d0c88f18aea8a66b59ca260e602feda3e5c4e5cd 100644 --- a/Grinder/engine/processors/ReplaceColorProcessor.cpp +++ b/Grinder/engine/processors/ReplaceColorProcessor.cpp @@ -46,6 +46,6 @@ void ReplaceColorProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/ResizeProcessor.cpp b/Grinder/engine/processors/ResizeProcessor.cpp index 1f275051d6b829c17adcb4e15457f4a60a44a81b..02688a67c7fbacc4043f07939a284a966d872243 100644 --- a/Grinder/engine/processors/ResizeProcessor.cpp +++ b/Grinder/engine/processors/ResizeProcessor.cpp @@ -51,6 +51,6 @@ void ResizeProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/SharpenProcessor.cpp b/Grinder/engine/processors/SharpenProcessor.cpp index d55c92aebc5c701ec8bf77bca771c9ea410bbab7..376333182e7f4d8412ed1841af55ffb53a780eae 100644 --- a/Grinder/engine/processors/SharpenProcessor.cpp +++ b/Grinder/engine/processors/SharpenProcessor.cpp @@ -49,6 +49,6 @@ void SharpenProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/engine/processors/SplitChannelsProcessor.cpp b/Grinder/engine/processors/SplitChannelsProcessor.cpp index 41d6228e0c8cb598677e102d2a08331a20070022..6514025bd8751866dfa38603b5323d6491bfd7ef 100644 --- a/Grinder/engine/processors/SplitChannelsProcessor.cpp +++ b/Grinder/engine/processors/SplitChannelsProcessor.cpp @@ -28,17 +28,17 @@ void SplitChannelsProcessor::execute(EngineExecutionContext& ctx) cv::Mat imageChannels[3]; cv::split(dataBlob->getMatrix(), imageChannels); - ctx.data().set(_block->channel1Port(), DataBlob{getPortDataDescriptor(_block->channel1Port()), std::move(imageChannels[0]), dataBlob->getColorSpace()}); - ctx.data().set(_block->channel2Port(), DataBlob{getPortDataDescriptor(_block->channel2Port()), std::move(imageChannels[1]), dataBlob->getColorSpace()}); - ctx.data().set(_block->channel3Port(), DataBlob{getPortDataDescriptor(_block->channel3Port()), std::move(imageChannels[2]), dataBlob->getColorSpace()}); + ctx.data().set(_block->channel1Port(), DataBlob{getPortDataDescriptor(_block->channel1Port()), std::move(imageChannels[0]), {dataBlob->metaData()}, dataBlob->getColorSpace()}); + ctx.data().set(_block->channel2Port(), DataBlob{getPortDataDescriptor(_block->channel2Port()), std::move(imageChannels[1]), {dataBlob->metaData()}, dataBlob->getColorSpace()}); + ctx.data().set(_block->channel3Port(), DataBlob{getPortDataDescriptor(_block->channel3Port()), std::move(imageChannels[2]), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } else { cv::Mat processedImage = getBypassedImage(dataBlob, true); - ctx.data().set(_block->channel1Port(), DataBlob{getPortDataDescriptor(_block->channel1Port()), processedImage, dataBlob->getColorSpace()}); - ctx.data().set(_block->channel2Port(), DataBlob{getPortDataDescriptor(_block->channel2Port()), processedImage, dataBlob->getColorSpace()}); - ctx.data().set(_block->channel3Port(), DataBlob{getPortDataDescriptor(_block->channel3Port()), processedImage, dataBlob->getColorSpace()}); + ctx.data().set(_block->channel1Port(), DataBlob{getPortDataDescriptor(_block->channel1Port()), processedImage, {dataBlob->metaData()}, dataBlob->getColorSpace()}); + ctx.data().set(_block->channel2Port(), DataBlob{getPortDataDescriptor(_block->channel2Port()), processedImage, {dataBlob->metaData()}, dataBlob->getColorSpace()}); + ctx.data().set(_block->channel3Port(), DataBlob{getPortDataDescriptor(_block->channel3Port()), processedImage, {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } } diff --git a/Grinder/engine/processors/WatershedProcessor.cpp b/Grinder/engine/processors/WatershedProcessor.cpp index e168821d82727f03cbd26de23ca8de349af96dcd..650aac5684040ffd5b7b78e9d44c978dca3d4868 100644 --- a/Grinder/engine/processors/WatershedProcessor.cpp +++ b/Grinder/engine/processors/WatershedProcessor.cpp @@ -60,6 +60,6 @@ void WatershedProcessor::execute(EngineExecutionContext& ctx) else processedImage = getBypassedImage(dataBlob); - ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), dataBlob->getColorSpace()}); + ctx.data().set(_block->outPort(), DataBlob{getPortDataDescriptor(_block->outPort()), std::move(processedImage), {dataBlob->metaData()}, dataBlob->getColorSpace()}); } } diff --git a/Grinder/ml/MachineLearningConfiguration.cpp b/Grinder/ml/MachineLearningConfiguration.cpp index 9552e659297afb8fff9590c8ea23ff5f7b94cfa5..0e0b7c8e6059bea0dffc9a1a974c4dacea2c8c59 100644 --- a/Grinder/ml/MachineLearningConfiguration.cpp +++ b/Grinder/ml/MachineLearningConfiguration.cpp @@ -5,3 +5,11 @@ #include "Grinder.h" #include "MachineLearningConfiguration.h" +#include "MachineLearningExceptions.h" + +void MachineLearningConfiguration::verifyConfiguration() const +{ + // Verify general settings + if (!_imageTags) + throw MachineLearningException{_EXCPT("Image tags may not be null")}; +} diff --git a/Grinder/ml/MachineLearningConfiguration.h b/Grinder/ml/MachineLearningConfiguration.h index 5f00f8fd294573e8aa3851f96ce2a8814a8a2308..89afff1a8591872940502bebb6a5f442df5102ee 100644 --- a/Grinder/ml/MachineLearningConfiguration.h +++ b/Grinder/ml/MachineLearningConfiguration.h @@ -10,12 +10,18 @@ namespace grndr { + class ImageTags; + class MachineLearningConfiguration : public QObject { Q_OBJECT public: - virtual void verifyConfiguration() const = 0; + virtual void verifyConfiguration() const; + + public: + const ImageTags* imageTags() const { return _imageTags; } + void setImageTags(const ImageTags* imageTags) { setValue(_imageTags, imageTags); } protected: template<typename ValueType> @@ -23,6 +29,9 @@ namespace grndr signals: void configurationChanged(); + + private: + const ImageTags* _imageTags{nullptr}; }; } diff --git a/Grinder/ml/MachineLearningMethodBase.h b/Grinder/ml/MachineLearningMethodBase.h index 6302def716a5808198eb795bd942b4c2a0bd2938..bca52f6915296a5722f993528c6d6f23426e6949 100644 --- a/Grinder/ml/MachineLearningMethodBase.h +++ b/Grinder/ml/MachineLearningMethodBase.h @@ -12,6 +12,7 @@ namespace grndr { class MachineLearningTaskSpawnerBase; + class MachineLearningConfiguration; class MachineLearningMethodBase : public QObject { @@ -23,6 +24,10 @@ namespace grndr public: virtual std::unique_ptr<MachineLearningTaskSpawnerBase> createTaskSpawner() const = 0; + public: + virtual MachineLearningConfiguration& config() = 0; + virtual const MachineLearningConfiguration& config() const = 0; + public: virtual QStringList getAvailableStates() const = 0; diff --git a/Grinder/ml/MachineLearningTaskSpawner.h b/Grinder/ml/MachineLearningTaskSpawner.h index b99048dba3c8523dfa585a744768b8ff34ec5b16..eff0053439b1e92ef7682a92db179a44a591d2f9 100644 --- a/Grinder/ml/MachineLearningTaskSpawner.h +++ b/Grinder/ml/MachineLearningTaskSpawner.h @@ -12,14 +12,14 @@ namespace grndr { class MachineLearningConfiguration; - class Task; + class MachineLearningTask; template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> class MachineLearningTaskSpawner : public MachineLearningTaskSpawnerBase { static_assert(std::is_base_of<MachineLearningConfiguration, ConfigType>::value, "ConfigType must be derived from MachineLearningConfiguration"); - static_assert(std::is_base_of<Task, TrainingTaskType>::value, "TrainingTaskType must be derived from Task"); - static_assert(std::is_base_of<Task, InferenceTaskType>::value, "TrainingTaskType must be derived from Task"); + static_assert(std::is_base_of<MachineLearningTask, TrainingTaskType>::value, "TrainingTaskType must be derived from MachineLearningPipelineTask"); + static_assert(std::is_base_of<MachineLearningTask, InferenceTaskType>::value, "TrainingTaskType must be derived from MachineLearningPipelineTask"); public: using config_type = ConfigType; @@ -30,8 +30,8 @@ namespace grndr MachineLearningTaskSpawner(const config_type& config); public: - virtual std::shared_ptr<Task> spawnTrainingTask(QString state, QString name) const override; - virtual std::shared_ptr<Task> spawnInferenceTask(QString state, QString name) const override; + virtual std::shared_ptr<MachineLearningTask> spawnTrainingTask(QString state, QString name) const override; + virtual std::shared_ptr<MachineLearningTask> spawnInferenceTask(QString state, QString name) const override; protected: virtual void configureTrainingTask(training_task_type* task, QString state) const = 0; diff --git a/Grinder/ml/MachineLearningTaskSpawner.impl.h b/Grinder/ml/MachineLearningTaskSpawner.impl.h index 8a87f14d4b17df564c3f20b7a8abd2dba2ccad96..ec9865f5a08bdfa97f7ddcd40f5dcaa52f469eba 100644 --- a/Grinder/ml/MachineLearningTaskSpawner.impl.h +++ b/Grinder/ml/MachineLearningTaskSpawner.impl.h @@ -16,27 +16,27 @@ MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::Mac } template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> -std::shared_ptr<Task> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::spawnTrainingTask(QString state, QString name) const +std::shared_ptr<MachineLearningTask> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::spawnTrainingTask(QString state, QString name) const { // If any of the following functions fails, an exception will be thrown auto task = createTrainingTask(name); configureTrainingTask(task.get(), state); - auto sharedTask = std::dynamic_pointer_cast<Task>(task); - grinder()->taskController().addTask(sharedTask); - return sharedTask; + auto sharedPtr = std::dynamic_pointer_cast<Task>(task); + grinder()->taskController().addTask(sharedPtr); + return std::dynamic_pointer_cast<MachineLearningTask>(task); } template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> -std::shared_ptr<Task> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::spawnInferenceTask(QString state, QString name) const +std::shared_ptr<MachineLearningTask> MachineLearningTaskSpawner<ConfigType, TrainingTaskType, InferenceTaskType>::spawnInferenceTask(QString state, QString name) const { // If any of the following functions fails, an exception will be thrown auto task = createInferenceTask(name); configureInferenceTask(task.get(), state); - auto sharedTask = std::dynamic_pointer_cast<Task>(task); - grinder()->taskController().addTask(sharedTask); - return sharedTask; + auto sharedPtr = std::dynamic_pointer_cast<Task>(task); + grinder()->taskController().addTask(sharedPtr); + return std::dynamic_pointer_cast<MachineLearningTask>(task); } template<typename ConfigType, typename TrainingTaskType, typename InferenceTaskType> diff --git a/Grinder/ml/MachineLearningTaskSpawnerBase.h b/Grinder/ml/MachineLearningTaskSpawnerBase.h index 669a46a2cdb25b6056a0a4c07f9128111072b5f0..2b9f6579f216edce5407038fa3fb8a0531b31d7c 100644 --- a/Grinder/ml/MachineLearningTaskSpawnerBase.h +++ b/Grinder/ml/MachineLearningTaskSpawnerBase.h @@ -10,13 +10,13 @@ namespace grndr { - class Task; + class MachineLearningTask; class MachineLearningTaskSpawnerBase { public: - virtual std::shared_ptr<Task> spawnTrainingTask(QString state, QString name) const = 0; - virtual std::shared_ptr<Task> spawnInferenceTask(QString state, QString name) const = 0; + virtual std::shared_ptr<MachineLearningTask> spawnTrainingTask(QString state, QString name) const = 0; + virtual std::shared_ptr<MachineLearningTask> spawnInferenceTask(QString state, QString name) const = 0; }; } diff --git a/Grinder/ml/barista/BaristaClassifierConfiguration.cpp b/Grinder/ml/barista/BaristaClassifierConfiguration.cpp index 0f6b6acf5ff1270eb28c1ba886b9fe1312d4c166..29846da391c396743adb5e34b3cb3b3d4f903e5f 100644 --- a/Grinder/ml/barista/BaristaClassifierConfiguration.cpp +++ b/Grinder/ml/barista/BaristaClassifierConfiguration.cpp @@ -19,6 +19,8 @@ BaristaClassifierConfiguration::BaristaClassifierConfiguration() void BaristaClassifierConfiguration::verifyConfiguration() const { + MachineLearningConfiguration::verifyConfiguration(); + // Verify general settings if (_baristaPort == 0) throw MachineLearningException{_EXCPT("Barista port may not be 0")}; diff --git a/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h b/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h index 57907f282ba7ffdd0a4162207cbfca564847f4e5..fb5a6e18da5ea5038addbab99ef53f763aafc148 100644 --- a/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h +++ b/Grinder/ml/barista/BaristaClassifierTaskSpawner.impl.h @@ -8,7 +8,7 @@ template<typename TaskType> void BaristaClassifierTaskSpawner::configureTask(TaskType* task) const -{ +{ // Apply general settings to the task task->setBaristaPort(_config.getBaristaPort()); task->setLibraryPath(_config.getLibraryPath()); @@ -16,4 +16,6 @@ void BaristaClassifierTaskSpawner::configureTask(TaskType* task) const task->setNetwork(_config.network()); task->setOutputDirectory(_config.getOutputDirectory()); task->setRemoteDirectory(_config.getRemoteDirectory()); + + task->setInputImageTags(_config.imageTags()); } diff --git a/Grinder/ml/barista/BaristaNetwork.cpp b/Grinder/ml/barista/BaristaNetwork.cpp index 033ca735e513ddb72e52b6a5af566a8a35db1a43..6b7710b898fa8c8263b8830eaf714b2129169400 100644 --- a/Grinder/ml/barista/BaristaNetwork.cpp +++ b/Grinder/ml/barista/BaristaNetwork.cpp @@ -32,9 +32,18 @@ void BaristaNetwork::compileNetwork(const BaristaNetworkContext& ctx) const for (const auto& file : networkFiles) compileNetworkFile(ctx, file); - // When training, export the selected images as training data + // Prepare training data if (ctx.getNetworkType() == BaristaNetworkContext::NetworkType::Training) - exportTrainingData(ctx); + { + // Ensure that the training data file has been created + QString trainingFile = ctx.resolveOutputFile(FILE_BARISTA_TRAINING_DATA_HDF5); + + if (!QFile::exists(trainingFile)) + throw BaristaNetworkException{this, _EXCPT(QString{"The training data file '%1' hasn't been created"}.arg(trainingFile))}; + + // Create the text file which points to the actual data file + createDataTextFile(ctx); + } } void BaristaNetwork::cleanupNetwork(const BaristaNetworkContext& ctx) const @@ -104,7 +113,7 @@ std::vector<QFileInfo> BaristaNetwork::assembleInferenceFiles() const return {_networkInfo.getInferenceNetwork()}; } -void BaristaNetwork::compileNetworkFile(const grndr::BaristaNetworkContext& ctx, const QFileInfo& fileInfo) const +void BaristaNetwork::compileNetworkFile(const BaristaNetworkContext& ctx, const QFileInfo& fileInfo) const { QString sourceFilename = fileInfo.filePath(); QFile file{sourceFilename}; @@ -130,40 +139,19 @@ void BaristaNetwork::compileNetworkFile(const grndr::BaristaNetworkContext& ctx, throw BaristaNetworkException{this, _EXCPT(QString{"Unable to open file '%1' for reading"}.arg(sourceFilename))}; } -void BaristaNetwork::exportTrainingData(const BaristaNetworkContext& ctx) const +void BaristaNetwork::createDataTextFile(const BaristaNetworkContext& ctx) const { - if (!ctx.getLabel() || !ctx.getCanvasBlock() || ctx.getImageReferences().empty()) - throw BaristaNetworkException{this, _EXCPT("Unable to export any training data")}; - - // Export the training data using the HDF5Exporter QString hdf5File = ctx.resolveOutputFile(FILE_BARISTA_TRAINING_DATA_HDF5); QString hdf5txtFile = ctx.resolveOutputFile(FILE_BARISTA_TRAINING_DATA_TXT); - try { - // Export the images - HDF5File::ExportFlags exportFlags{HDF5File::ExportFlag::ExportTags}; - - if (_networkInfo.mergeTags()) - exportFlags |= HDF5File::ExportFlag::MergeTags; - - if (_networkInfo.requiresGrayscale()) - exportFlags |= HDF5File::ExportFlag::ExportAsGrayscale; + // Create a HDF5txt file + QFile file{hdf5txtFile}; - HDF5Exporter exporter{ctx.getLabel(), ctx.getCanvasBlock(), ctx.getImageReferences(), exportFlags}; - exporter.exportProject(&grinder()->project(), hdf5File); - - // Create a HDF5txt file - QFile file{hdf5txtFile}; - - if (file.open(QIODevice::WriteOnly|QIODevice::Text|QIODevice::Truncate)) - { - QTextStream streamOut{&file}; - streamOut << ctx.resolveRemoteFile(hdf5File) << "\n"; - } - else - throw BaristaNetworkException{this, _EXCPT(QString{"Unable to open file '%1' for writing"}.arg(hdf5txtFile))}; - } catch (ExportException& e) { - // Re-throw the export exception as a BaristaNetworkException - throw BaristaNetworkException{this, _EXCPT(QString{"Unable to export the training data to '%1' (%2)"}.arg(hdf5File).arg(GetExceptionMessage(e.what())))}; + if (file.open(QIODevice::WriteOnly|QIODevice::Text|QIODevice::Truncate)) + { + QTextStream streamOut{&file}; + streamOut << ctx.resolveRemoteFile(hdf5File) << "\n"; } + else + throw BaristaNetworkException{this, _EXCPT(QString{"Unable to open file '%1' for writing"}.arg(hdf5txtFile))}; } diff --git a/Grinder/ml/barista/BaristaNetwork.h b/Grinder/ml/barista/BaristaNetwork.h index 7fd74290d12d860ba8c22f6a4bf832b417a239db..32b9593b71f5b8082dfbff89451b61233f8ae565 100644 --- a/Grinder/ml/barista/BaristaNetwork.h +++ b/Grinder/ml/barista/BaristaNetwork.h @@ -34,7 +34,7 @@ namespace grndr void compileNetworkFile(const BaristaNetworkContext& ctx, const QFileInfo& fileInfo) const; - void exportTrainingData(const BaristaNetworkContext& ctx) const; + void createDataTextFile(const BaristaNetworkContext& ctx) const; private: BaristaNetworkInfo _networkInfo; diff --git a/Grinder/ml/barista/BaristaNetworkContext.h b/Grinder/ml/barista/BaristaNetworkContext.h index 193bb38a5f546c2fb4e5b1aa95a6f2894d59d8c0..8aec5253867fe98fa9f1d5d5dc26bbff218ed553 100644 --- a/Grinder/ml/barista/BaristaNetworkContext.h +++ b/Grinder/ml/barista/BaristaNetworkContext.h @@ -51,13 +51,6 @@ namespace grndr void removeVariable(QString name) { _variables.erase(name); } void replaceVariables(QString& text) const; - Label* getLabel() const { return _label; } - void setLabel(Label* label) { _label = label; } - Block* getCanvasBlock() const { return _canvasBlock; } - void setCanvasBlock(Block* block) { _canvasBlock = block; } - const ImageReferenceSelection& getImageReferences() const { return _imageReferences; } - void setImageReferences(const ImageReferenceSelection& imageRefs) { _imageReferences = imageRefs; } - private: NetworkType _networkType; @@ -65,10 +58,6 @@ namespace grndr QString _remoteDirectory; std::map<QString, QVariant> _variables; - - Label* _label{nullptr}; - Block* _canvasBlock{nullptr}; - ImageReferenceSelection _imageReferences; }; } diff --git a/Grinder/ml/barista/BaristaNetworkInfo.h b/Grinder/ml/barista/BaristaNetworkInfo.h index 32b4f9b6bfc0b3372d8ef9442007602ec3255fc7..fce40a7acdec223748b57ce32278d4522a7d6e4e 100644 --- a/Grinder/ml/barista/BaristaNetworkInfo.h +++ b/Grinder/ml/barista/BaristaNetworkInfo.h @@ -23,7 +23,6 @@ namespace grndr public: QString getNetworkName() const { return _settings.value("BaristaNet/Name").toString(); } - bool mergeTags() const { return _settings.value("BaristaNet.Settings/MergeTags").toBool(); } bool requiresGrayscale() const { return _settings.value("BaristaNet.Settings/Grayscale").toBool(); } QString getTrainingSolver(bool getFullPath = true) const { return getFilename("Training/Solver", getFullPath); } diff --git a/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp b/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp index ce428c5d9710f3ef619612e829550c44f2e8285d..2821d8b2817d7d79d37a6592376ffc49584005c5 100644 --- a/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp +++ b/Grinder/ml/barista/blocks/BaristaClassifierBlock.cpp @@ -66,6 +66,8 @@ bool BaristaClassifierBlock::updateProperties(PropertyBase* updatedProp) void BaristaClassifierBlock::updateConfiguration() { + MachineLearningMethodBlock::updateConfiguration(); + _method.config().setBaristaPort(*baristaPort()); _method.config().setLibraryPath(*libraryPath()); _method.config().setNetwork(*network()); diff --git a/Grinder/ml/barista/tasks/BaristaInferenceTask.cpp b/Grinder/ml/barista/tasks/BaristaInferenceTask.cpp index db3b4010507259223703afbb66a6f5cf512f3581..1174f39ad690ec9ccbba4030063a0e5601ab1e0c 100644 --- a/Grinder/ml/barista/tasks/BaristaInferenceTask.cpp +++ b/Grinder/ml/barista/tasks/BaristaInferenceTask.cpp @@ -102,7 +102,7 @@ void BaristaInferenceTask::sendLoadNetworkMessage() void BaristaInferenceTask::sendInferImageMessage(unsigned int inferImageIndex) { - if (inferImageIndex < _imageReferences.size()) // Any images left to infer? +/* if (inferImageIndex < _imageReferences.size()) // Any images left to infer? { auto imageRef = _imageReferences[inferImageIndex]; addLogMessage(QString{"\tPerforming inference on '%1'..."}.arg(imageRef->getImageFileName())); @@ -142,7 +142,7 @@ void BaristaInferenceTask::sendInferImageMessage(unsigned int inferImageIndex) // The inference has finished, so break the Barista connection and finish the task shutdownBaristaTask(); finishTask(true); - } + }*/ } std::unique_ptr<BaristaMessage> BaristaInferenceTask::handleLoadNetworkMessage(BaristaMessage* message) @@ -172,7 +172,7 @@ std::unique_ptr<BaristaMessage> BaristaInferenceTask::handleLoadNetworkMessage(B } std::unique_ptr<BaristaMessage> BaristaInferenceTask::handleInferImageMessage(BaristaMessage* message) -{ +{/* if (_taskState == InferenceTaskState::InferImages) { if (message->getStatus()) @@ -203,14 +203,14 @@ std::unique_ptr<BaristaMessage> BaristaInferenceTask::handleInferImageMessage(Ba } else reportUnexpectedMessage(message); - +*/ return nullptr; } void BaristaInferenceTask::processInferenceResult_Probabilities(const ImageReference* imageRef, const BaristaInferImageMessage* result, QString outputName) { auto dims = result->getDimensions(outputName); - +/* if (dims.size() == 4 && dims[0] == 1) // Four dimensions: # of images (must be 1), # of labels, height, width { // Execute the label so that an image build is created @@ -242,6 +242,7 @@ void BaristaInferenceTask::processInferenceResult_Probabilities(const ImageRefer } else throw BaristaException{_EXCPT("Invalid dimensions")}; + */ } void BaristaInferenceTask::createProbabilityItems(ImageBuild* imageBuild, cv::Mat& probData, QSize dataSize, int tagCount, const ImageTagVector* imageTags) diff --git a/Grinder/ml/barista/tasks/BaristaInferenceTask.h b/Grinder/ml/barista/tasks/BaristaInferenceTask.h index c25d3d24d2dfcf5d0ed1deaaf2e17ca20a01d44d..acab89c19bdd32517ba9ae13130f556a83e13cf4 100644 --- a/Grinder/ml/barista/tasks/BaristaInferenceTask.h +++ b/Grinder/ml/barista/tasks/BaristaInferenceTask.h @@ -85,9 +85,9 @@ namespace grndr private: void processInferenceResult_Probabilities(const ImageReference* imageRef, const BaristaInferImageMessage* result, QString outputName); - void createProbabilityItems(ImageBuild* imageBuild, cv::Mat& probData, QSize dataSize, int tagCount, const ImageTagVector* imageTags); + void createProbabilityItems(ImageBuild* imageBuild, cv::Mat& probData, QSize dataSize, int tagCount, const ImageTagVector* inputImageTags); void createProbabilityItems(Layer* layer, cv::Mat& probData, QSize dataSize, int tagIndex, ImageTag* imageTag); - void createProbabilityMaps(ImageBuild* imageBuild, cv::Mat& probData, QSize dataSize, int tagCount, const ImageTagVector* imageTags); + void createProbabilityMaps(ImageBuild* imageBuild, cv::Mat& probData, QSize dataSize, int tagCount, const ImageTagVector* inputImageTags); void createProbabilityMap(ImageBuild* imageBuild, cv::Mat& probData, QSize dataSize, int tagIndex, ImageTag* imageTag); cv::Mat extractProbabilityData(cv::Mat& probData, QSize dataSize, int index) const; diff --git a/Grinder/ml/barista/tasks/BaristaTask.h b/Grinder/ml/barista/tasks/BaristaTask.h index 86de15206aa96f40321745bc0c94c4ae42ad8494..759e8876d2ddeb6fcb351137616422eefce9b6d0 100644 --- a/Grinder/ml/barista/tasks/BaristaTask.h +++ b/Grinder/ml/barista/tasks/BaristaTask.h @@ -9,13 +9,13 @@ #include "ml/barista/BaristaInterface.h" #include "ml/barista/BaristaNetwork.h" #include "ml/barista/BaristaNetworkContext.h" -#include "task/Task.h" +#include "ml/tasks/MachineLearningTask.h" #include "project/ImageReferenceSelection.h" namespace grndr { template<typename ClassType> - class BaristaTask : public Task, public NetworkMessageHandler<ClassType, BaristaMessage> + class BaristaTask : public MachineLearningTask, public NetworkMessageHandler<ClassType, BaristaMessage> { public: static const char* Serialization_Value_BaristaPort; @@ -23,8 +23,6 @@ namespace grndr static const char* Serialization_Value_Network; static const char* Serialization_Value_OutputDirectory; static const char* Serialization_Value_RemoteDirectory; - static const char* Serialization_Value_Label; - static const char* Serialization_Value_CanvasBlock; public: using class_type = ClassType; @@ -51,12 +49,9 @@ namespace grndr QString getRemoteDirectory() const { return _remoteDirectory; } void setRemoteDirectory(QString dir) { _remoteDirectory = dir; } - Label* getLabel() const { return _label; } - void setLabel(Label* label); - Block* getCanvasBlock() const { return _canvasBlock; } - void setCanvasBlock(Block* block); - const ImageReferenceSelection& getImageReferences() const { return _imageReferences; } - void setImageReferences(const ImageReferenceSelection& imageRefs); + public: + const ImageTags* inputImageTags() const { return _inputImageTags; } + void setInputImageTags(const ImageTags* imageTags) { _inputImageTags = imageTags; } public: virtual void serialize(SerializationContext& ctx) const override; @@ -69,7 +64,7 @@ namespace grndr virtual void stop() override; virtual void update() override; - virtual void finish(bool succeeded) override; + virtual void finish(bool succeeded) override; protected: void initializeBaristaTask(); @@ -91,9 +86,6 @@ namespace grndr void baristaClientReady(QString ip); void baristaClientFailure(QString error) { reportError(error); } - void labelRemoved(const std::shared_ptr<Label>& label); - void blockRemoved(const std::shared_ptr<Block>& block); - protected: unsigned int _baristaPort{6980}; @@ -104,9 +96,8 @@ namespace grndr QString _outputDirectory{""}; QString _remoteDirectory{""}; - Label* _label{nullptr}; - Block* _canvasBlock{nullptr}; - ImageReferenceSelection _imageReferences; + protected: + const ImageTags* _inputImageTags{nullptr}; protected: BaristaInterface _baristaInterface; diff --git a/Grinder/ml/barista/tasks/BaristaTask.impl.h b/Grinder/ml/barista/tasks/BaristaTask.impl.h index 11a7703fafcb058576d9c0175856d3f096ad5c90..7f4d1ffee157637bfd5c8dee59a90f3fe9f0cb45 100644 --- a/Grinder/ml/barista/tasks/BaristaTask.impl.h +++ b/Grinder/ml/barista/tasks/BaristaTask.impl.h @@ -21,19 +21,15 @@ template<typename ClassType> const char* BaristaTask<ClassType>::Serialization_Value_OutputDirectory = "OutputDirectory"; template<typename ClassType> const char* BaristaTask<ClassType>::Serialization_Value_RemoteDirectory = "RemoteDirectory"; -template<typename ClassType> -const char* BaristaTask<ClassType>::Serialization_Value_Label = "Label"; -template<typename ClassType> -const char* BaristaTask<ClassType>::Serialization_Value_CanvasBlock = "CanvasBlock"; template<typename ClassType> -BaristaTask<ClassType>::BaristaTask(class_type* handlerTarget, TaskPool* taskPool, TaskType type, QString name) : Task(taskPool, type, Task::Capability::CanBeStopped|Task::Capability::HasProgress, name), NetworkMessageHandler<ClassType, BaristaMessage>(&_baristaInterface, handlerTarget) +BaristaTask<ClassType>::BaristaTask(class_type* handlerTarget, TaskPool* taskPool, TaskType type, QString name) : MachineLearningTask(taskPool, type, Task::Capability::CanBeStopped|Task::Capability::HasProgress, name), NetworkMessageHandler<ClassType, BaristaMessage>(&_baristaInterface, handlerTarget) { // Register the task as a message handler _baristaInterface.registerMessageHandler(this); // When the task has been stopped, immediately finish it - connect(this, &Task::taskStopped, [this]() { finishTask(false); }); + connect(this, &MachineLearningTask::taskStopped, [this]() { finishTask(false); }); // The Barista interface will notify us when it is ready or an error occurred connect(&_baristaInterface, &BaristaInterface::clientReady, this, &BaristaTask<ClassType>::baristaClientReady); @@ -52,42 +48,10 @@ void BaristaTask<ClassType>::setNetwork(BaristaNetwork* network) _network = network; } -template<typename ClassType> -void BaristaTask<ClassType>::setLabel(Label* label) -{ - if (_label) - disconnect(&grinder()->project(), nullptr, this, nullptr); - - _label = label; - - // Listen to removed labels to reset the assigned label if necessary - if (_label) - connect(&grinder()->project(), &Project::labelRemoved, this, &BaristaTask<ClassType>::labelRemoved); -} - -template<typename ClassType> -void BaristaTask<ClassType>::setCanvasBlock(Block* block) -{ - if (_canvasBlock) - disconnect(_canvasBlock->pipeline(), nullptr, this, nullptr); - - _canvasBlock = block; - - // Listen to removed blocks to reset the assigned canvas block if necessary - if (_canvasBlock) - connect(_canvasBlock->pipeline(), &Pipeline::blockRemoved, this, &BaristaTask<ClassType>::blockRemoved); -} - -template<typename ClassType> -void BaristaTask<ClassType>::setImageReferences(const ImageReferenceSelection& imageRefs) -{ - _imageReferences = imageRefs; -} - template<typename ClassType> void BaristaTask<ClassType>::serialize(SerializationContext& ctx) const { - Task::serialize(ctx); + MachineLearningTask::serialize(ctx); // Serialize values ctx.settings()(Serialization_Value_BaristaPort) = _baristaPort; @@ -95,17 +59,12 @@ void BaristaTask<ClassType>::serialize(SerializationContext& ctx) const ctx.settings()(Serialization_Value_Network) = _network ? _network->networkInfo().getNetworkName() : QString{""}; ctx.settings()(Serialization_Value_OutputDirectory) = _outputDirectory; ctx.settings()(Serialization_Value_RemoteDirectory) = _remoteDirectory; - ctx.settings()(Serialization_Value_Label) = ctx.getLabelIndex(_label); - ctx.settings()(Serialization_Value_CanvasBlock) = ctx.getBlockIndex(_canvasBlock); - - // Serialize image references - _imageReferences.serialize(ctx); } template<typename ClassType> void BaristaTask<ClassType>::deserialize(DeserializationContext& ctx) { - Task::deserialize(ctx); + MachineLearningTask::deserialize(ctx); // Deserialize values QString networkName = ctx.settings()(Serialization_Value_Network).toString(); @@ -115,17 +74,12 @@ void BaristaTask<ClassType>::deserialize(DeserializationContext& ctx) setNetwork(!networkName.isEmpty() ? grinder()->externalDataManager().baristaNetworkPool().findNetwork(networkName) : nullptr); _outputDirectory = ctx.settings()(Serialization_Value_OutputDirectory).toString(); _remoteDirectory = ctx.settings()(Serialization_Value_RemoteDirectory).toString(); - setLabel(ctx.getLabel(ctx.settings()(Serialization_Value_Label, -1).toInt())); - setCanvasBlock(ctx.getBlock(ctx.settings()(Serialization_Value_CanvasBlock, -1).toInt())); - - // Deserialize image references - _imageReferences.deserialize(ctx); } template<typename ClassType> void BaristaTask<ClassType>::verifyTask() const { - Task::verifyTask(); + MachineLearningTask::verifyTask(); if (_libraryPath.isEmpty()) throw TaskException{this, _EXCPT("No library path provided")}; @@ -135,15 +89,6 @@ void BaristaTask<ClassType>::verifyTask() const if (_outputDirectory.isEmpty()) throw TaskException{this, _EXCPT("No output directory provided")}; - - if (!_label) - throw TaskException{this, _EXCPT("No label provided")}; - - if (!_canvasBlock) - throw TaskException{this, _EXCPT("No canvas block provided")}; - - if (_imageReferences.empty()) - throw TaskException{this, _EXCPT("No images selected")}; } template<typename ClassType> @@ -151,7 +96,7 @@ void BaristaTask<ClassType>::execute() { initializeBaristaTask(); - Task::execute(); + MachineLearningTask::execute(); } template<typename ClassType> @@ -159,7 +104,7 @@ void BaristaTask<ClassType>::stop() { shutdownBaristaTask(); - Task::stop(); + MachineLearningTask::stop(); } template<typename ClassType> @@ -250,22 +195,10 @@ void BaristaTask<ClassType>::reportUnexpectedMessage(BaristaMessage* message) template<typename ClassType> void BaristaTask<ClassType>::prepareBaristaNetworkContext() { - // Assign some variables - _networkContext->setLabel(_label); - _networkContext->setCanvasBlock(_canvasBlock); - _networkContext->setImageReferences(_imageReferences); - // Output count (based on available image tags) - unsigned int tagsCount = 0; - - if (_canvasBlock) - { - if (auto imageTagsProperty = _canvasBlock->portProperty<ImageTagsProperty>(PortType::ImageTagsIn, PropertyID::ImageTags)) - tagsCount = imageTagsProperty->object().tags().size(); - } - - _networkContext->addVariable(BaristaNetworkContext::Variable_OutputCount, tagsCount); - _networkContext->addVariable(BaristaNetworkContext::Variable_OutputCountPlus1, tagsCount + 1); + unsigned int imageTagsCount = _inputImageTags ? _inputImageTags->tags().size() : 0; + _networkContext->addVariable(BaristaNetworkContext::Variable_OutputCount, imageTagsCount); + _networkContext->addVariable(BaristaNetworkContext::Variable_OutputCountPlus1, imageTagsCount + 1); } template<typename ClassType> @@ -295,17 +228,3 @@ void BaristaTask<ClassType>::changeTaskState(int state, QString msg) addLogMessage(msg); } } - -template<typename ClassType> -void BaristaTask<ClassType>::labelRemoved(const std::shared_ptr<Label>& label) -{ - if (label.get() == _label) - setLabel(nullptr); -} - -template<typename ClassType> -void BaristaTask<ClassType>::blockRemoved(const std::shared_ptr<Block>& block) -{ - if (block.get() == _canvasBlock) - setCanvasBlock(nullptr); -} diff --git a/Grinder/ml/barista/tasks/BaristaTrainingTask.cpp b/Grinder/ml/barista/tasks/BaristaTrainingTask.cpp index 61352862383f1124dd5b91c652ab1062797f76d5..24b9a1c54edb97bf479222a0412e45988771aa67 100644 --- a/Grinder/ml/barista/tasks/BaristaTrainingTask.cpp +++ b/Grinder/ml/barista/tasks/BaristaTrainingTask.cpp @@ -8,6 +8,7 @@ #include "ml/barista/BaristaMessage.h" #include "task/TaskExceptions.h" #include "pipeline/Block.h" +#include "engine/EngineExecutionContext.h" #include "image/properties/ImageTagsProperty.h" #include "ui/barista/tasks/BaristaTrainingTaskWidget.h" #include "res/Filenames.h" @@ -41,6 +42,46 @@ ConfigureTaskWidgetBase* BaristaTrainingTask::createEditor(bool newTask, QWidget return new BaristaTrainingTaskWidget{this, newTask, parent}; } +void BaristaTrainingTask::processEngineStart(EngineExecutionContext& ctx, const MachineLearningTaskData& data) +{ + BaristaTask::processEngineStart(ctx, data); + + // Create the HDF5 file in the output directory + QFileInfo fi{_outputDirectory, FILE_BARISTA_TRAINING_DATA_HDF5}; + _h5Export = std::make_unique<HDF5Export>(fi.filePath()); + + HDF5Export::ExportFlags exportFlags = HDF5Export::ExportFlag::ExportTags|HDF5Export::ExportFlag::MergeTags; + + if (_network->networkInfo().requiresGrayscale()) + exportFlags |= HDF5Export::ExportFlag::ExportAsGrayscale; + + _h5Export->initExport(QSize{data.imageData.cols, data.imageData.rows}, ctx.imageReferences().size(), 1, exportFlags); +} + +void BaristaTrainingTask::processEnginePass(EngineExecutionContext& ctx, const MachineLearningTaskData& data) +{ + BaristaTask::processEnginePass(ctx, data); + + if (!_h5Export) + { + reportError("The HDF5 file isn't ready for exporting"); + return; + } + + if (!data.imageTagsData.empty()) + _h5Export->exportImageEx(data.imageData, {data.imageTagsData}); + else + _h5Export->exportImage(data.imageData); +} + +void BaristaTrainingTask::processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) +{ + BaristaTask::processEngineEnd(ctx, data); + + // We no longer need the HDF5 export + _h5Export = nullptr; +} + void BaristaTrainingTask::serialize(SerializationContext& ctx) const { BaristaTask::serialize(ctx); @@ -61,22 +102,6 @@ void BaristaTrainingTask::deserialize(DeserializationContext& ctx) _snapshotInterval = ctx.settings()(Serialization_Value_SnapshotInterval, 1000).toUInt(); } -void BaristaTrainingTask::verifyTask() const -{ - BaristaTask::verifyTask(); - - if (_canvasBlock) - { - bool hasTags = false; - - if (auto imageTagsProperty = _canvasBlock->portProperty<ImageTagsProperty>(PortType::ImageTagsIn, PropertyID::ImageTags)) - hasTags = !imageTagsProperty->object().tags().empty(); - - if (!hasTags) - throw TaskException{this, _EXCPT("No image tags provided")}; - } -} - void BaristaTrainingTask::createBaristaNetworkContext() { _networkContext = std::make_unique<BaristaNetworkContext>(BaristaNetworkContext::NetworkType::Training, _outputDirectory, _remoteDirectory); diff --git a/Grinder/ml/barista/tasks/BaristaTrainingTask.h b/Grinder/ml/barista/tasks/BaristaTrainingTask.h index 26c55dfd4a19d720167fc167058793b403a5fedd..3fda464e8b8a8075f397e9be87e3f75e431a555f 100644 --- a/Grinder/ml/barista/tasks/BaristaTrainingTask.h +++ b/Grinder/ml/barista/tasks/BaristaTrainingTask.h @@ -7,6 +7,7 @@ #define BARISTATRAININGTASK_H #include "BaristaTask.h" +#include "project/exporters/HDF5Export.h" namespace grndr { @@ -30,6 +31,11 @@ namespace grndr public: virtual ConfigureTaskWidgetBase* createEditor(bool newTask, QWidget* parent) override; + public: + virtual void processEngineStart(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; + virtual void processEnginePass(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; + virtual void processEngineEnd(EngineExecutionContext& ctx, const MachineLearningTaskData& data) override; + public: unsigned int getMaxIterations() const { return _maxIterations; } void setMaxIterations(unsigned int maxIter) { _maxIterations = maxIter; } @@ -42,9 +48,6 @@ namespace grndr virtual void serialize(SerializationContext& ctx) const override; virtual void deserialize(DeserializationContext& ctx) override; - protected: - virtual void verifyTask() const override; - protected: virtual void createBaristaNetworkContext() override; virtual void prepareBaristaNetworkContext() override; @@ -70,6 +73,9 @@ namespace grndr unsigned int _maxIterations{5000}; unsigned int _displayInterval{100}; unsigned int _snapshotInterval{1000}; + + private: + std::unique_ptr<HDF5Export> _h5Export; }; } diff --git a/Grinder/ml/blocks/MachineLearningMethodBlock.h b/Grinder/ml/blocks/MachineLearningMethodBlock.h index 9cf873b7bb3083c79bf1d86ad0593ed5399a8474..e82604750686d011a3ea5522708d1a3d72403e54 100644 --- a/Grinder/ml/blocks/MachineLearningMethodBlock.h +++ b/Grinder/ml/blocks/MachineLearningMethodBlock.h @@ -38,6 +38,8 @@ namespace grndr auto state() { return dynamic_cast<MachineLearningStateProperty*>(_state.get()); } auto state() const { return dynamic_cast<const MachineLearningStateProperty*>(_state.get()); } + Port* imageTagsPort() { return _imageTagsPort.get(); } + const Port* imageTagsPort() const { return _imageTagsPort.get(); } Port* methodPort() { return _methodPort.get(); } const Port* methodPort() const { return _methodPort.get(); } Port* statePort() { return _statePort.get(); } @@ -49,17 +51,21 @@ namespace grndr virtual void createPorts() override; protected: - virtual void updateConfiguration() = 0; + virtual void updateConfiguration(); protected: bool checkStateAvailability(); + private: + void imageTagsPortConnectionChanged(const Connection* connection) { Q_UNUSED(connection); updateConfiguration(); } + protected: method_type _method; protected: std::shared_ptr<PropertyBase> _state; + std::shared_ptr<Port> _imageTagsPort; std::shared_ptr<Port> _methodPort; std::shared_ptr<Port> _statePort; }; diff --git a/Grinder/ml/blocks/MachineLearningMethodBlock.impl.h b/Grinder/ml/blocks/MachineLearningMethodBlock.impl.h index fef7a1847ed8161add7cce1283b3124323b95fbe..e4393eae5fb97f606ecee2518ccee9768e52f984 100644 --- a/Grinder/ml/blocks/MachineLearningMethodBlock.impl.h +++ b/Grinder/ml/blocks/MachineLearningMethodBlock.impl.h @@ -18,6 +18,10 @@ void MachineLearningMethodBlock<MethodType>::initBlock() { Block::initBlock(); + // Listen to connection changes on the image tags port + connect(_imageTagsPort.get(), &Port::portConnected, this, &MachineLearningMethodBlock<MethodType>::imageTagsPortConnectionChanged); + connect(_imageTagsPort.get(), &Port::portDisconnected, this, &MachineLearningMethodBlock<MethodType>::imageTagsPortConnectionChanged); + updateConfiguration(); } @@ -53,6 +57,9 @@ bool MachineLearningMethodBlock<MethodType>::updateProperties(PropertyBase* upda template<typename MethodType> void MachineLearningMethodBlock<MethodType>::createPorts() { + DataDescriptors imageTagsPortDataDescs = {DataDescriptor::customDescriptor("Image tags", DataType::ImageTags)}; + _imageTagsPort = createPort(PortType::ImageTagsIn, Port::Direction::In, imageTagsPortDataDescs, "Tags"); + DataDescriptors methodPortDataDescs = {DataDescriptor::customDescriptor("Machine learning method", DataType::MachineLearningMethod)}; _methodPort = createPort(PortType::Method, Port::Direction::Out, methodPortDataDescs, "Method"); @@ -60,6 +67,15 @@ void MachineLearningMethodBlock<MethodType>::createPorts() _statePort = createPort(PortType::State, Port::Direction::Out, statePortDataDescs, "State"); } +template<typename MethodType> +void MachineLearningMethodBlock<MethodType>::updateConfiguration() +{ + if (auto imageTagsProperty = this->template portProperty<ImageTagsProperty>(PortType::ImageTagsIn, PropertyID::ImageTags)) + _method.config().setImageTags(&imageTagsProperty->object()); + else + _method.config().setImageTags(nullptr); +} + template<typename MethodType> bool MachineLearningMethodBlock<MethodType>::checkStateAvailability() { diff --git a/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h b/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h index 9c88e628246150dbc2789d957cdc8f1e35637ee5..94a7df737070391c9ac53963f8fd89366faf028e 100644 --- a/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h +++ b/Grinder/ml/processors/MachineLearningMethodProcessor.impl.h @@ -6,6 +6,7 @@ #include "Grinder.h" #include "MachineLearningMethodProcessor.h" #include "ml/MachineLearningMethodBase.h" +#include "image/properties/ImageTagsProperty.h" Q_DECLARE_METATYPE(const MachineLearningMethodBase*) diff --git a/Grinder/ml/processors/MachineLearningProcessor.h b/Grinder/ml/processors/MachineLearningProcessor.h index 067c813cd2f33a94d0c5f3394d6a1b197e1134ca..cd34cfbc1e57c1c2525ecc1116362c165313dd19 100644 --- a/Grinder/ml/processors/MachineLearningProcessor.h +++ b/Grinder/ml/processors/MachineLearningProcessor.h @@ -7,12 +7,12 @@ #define MACHINELEARNINGPROCESSOR_H #include "engine/Processor.h" +#include "ml/tasks/MachineLearningTask.h" namespace grndr { class MachineLearningBlock; class MachineLearningMethodBase; - class Task; template<typename BlockType> class MachineLearningProcessor : public Processor<BlockType> @@ -22,6 +22,8 @@ namespace grndr private: static const char* Data_Value_Method; static const char* Data_Value_State; + static const char* Data_Value_ImageTags; + static const char* Data_Value_SpawnedTask; protected: enum class SpawnType @@ -31,19 +33,28 @@ namespace grndr }; public: - using Processor<BlockType>::Processor; + MachineLearningProcessor(const Block* block, SpawnType spawnType, bool requiresBatchMode = false); public: virtual void execute(EngineExecutionContext& ctx) override; protected: - virtual void execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) = 0; + virtual bool execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) { Q_UNUSED(ctx); Q_UNUSED(method); Q_UNUSED(state); return false; } - protected: - std::shared_ptr<Task> spawnTask(SpawnType type, const MachineLearningMethodBase* method, QString state) const; + virtual void fillTaskData(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, MachineLearningTaskData& taskData) const; + + private: + void spawnTask(const MachineLearningMethodBase* method, QString state); private: QString getSpawnTypeName(SpawnType type) const; + + protected: + std::shared_ptr<MachineLearningTask> _spawnedTask; + + private: + SpawnType _spawnType{SpawnType::Training}; + bool _requiresBatchMode{false}; }; } diff --git a/Grinder/ml/processors/MachineLearningProcessor.impl.h b/Grinder/ml/processors/MachineLearningProcessor.impl.h index d130f029c613fcb02e81da3b472f10387965f3f5..2fa4f3e6ac654caddd9969e437894e6a0f5ac088 100644 --- a/Grinder/ml/processors/MachineLearningProcessor.impl.h +++ b/Grinder/ml/processors/MachineLearningProcessor.impl.h @@ -7,13 +7,27 @@ #include "MachineLearningProcessor.h" #include "ml/MachineLearningMethodBase.h" #include "ml/MachineLearningTaskSpawnerBase.h" +#include "ml/MachineLearningConfiguration.h" +#include "image/ImageTags.h" template<typename BlockType> const char* MachineLearningProcessor<BlockType>::Data_Value_Method = "Method"; template<typename BlockType> const char* MachineLearningProcessor<BlockType>::Data_Value_State = "State"; +template<typename BlockType> +const char* MachineLearningProcessor<BlockType>::Data_Value_ImageTags = "ImageTags"; +template<typename BlockType> +const char* MachineLearningProcessor<BlockType>::Data_Value_SpawnedTask = "SpawnedTask"; Q_DECLARE_METATYPE(const MachineLearningMethodBase*) +Q_DECLARE_METATYPE(std::shared_ptr<MachineLearningTask>) + +template<typename BlockType> +MachineLearningProcessor<BlockType>::MachineLearningProcessor(const Block* block, MachineLearningProcessor::SpawnType spawnType, bool requiresBatchMode) : Processor<BlockType>(block), + _spawnType{spawnType}, _requiresBatchMode{requiresBatchMode} +{ + +} template<typename BlockType> void MachineLearningProcessor<BlockType>::execute(EngineExecutionContext& ctx) @@ -24,35 +38,86 @@ void MachineLearningProcessor<BlockType>::execute(EngineExecutionContext& ctx) const MachineLearningMethodBase* method = this->template portData<const MachineLearningMethodBase*>(ctx, this->_block->methodPort(), Data_Value_Method); QString state = this->template portData<QString>(ctx, this->_block->statePort(), Data_Value_State, false); - execute(ctx, method, state); + if (!execute(ctx, method, state)) + { + // Take care of the machine learning task + if (!_requiresBatchMode || ctx.hasExecutionFlag(Engine::ExecutionFlag::Batch)) + { + // Spawn the task when the first image is active + if (ctx.isFirstImage()) + { + spawnTask(method, state); + ctx.persistentData().set(this->_block, Data_Value_SpawnedTask, _spawnedTask); // Store the spawned task in the persistent context data + } + else + { + _spawnedTask = ctx.persistentData().get<std::shared_ptr<MachineLearningTask>>(this->block(), Data_Value_SpawnedTask); // Retrieve the stored spawned task from the persistent context data + + if (!_spawnedTask) + this->throwProcessorException("No machine learning task has been spawned"); + } + + // Get the task data for machine learning + MachineLearningTaskData taskData; + fillTaskData(ctx, method, taskData); + + if (ctx.isFirstImage()) + _spawnedTask->processEngineStart(ctx, taskData); + + _spawnedTask->processEnginePass(ctx, taskData); + + if (ctx.isLastImage()) + { + _spawnedTask->processEngineEnd(ctx, taskData); + + // Forget the spawned task + _spawnedTask = nullptr; + ctx.persistentData().remove(this->_block, Data_Value_SpawnedTask); + } + } + else + { + QString name = getSpawnTypeName(_spawnType); + this->throwProcessorException(QString{"%1 is only possible in batch mode; bypass the %2 block to avoid this warning"}.arg(name).arg(name.toLower())); + } + } } } template<typename BlockType> -std::shared_ptr<Task> MachineLearningProcessor<BlockType>::spawnTask(SpawnType type, const MachineLearningMethodBase* method, QString state) const +void MachineLearningProcessor<BlockType>::fillTaskData(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, MachineLearningTaskData& taskData) const { + Q_UNUSED(method); + + if (auto dataBlob = this->portData(ctx, this->_block->inPort())) + taskData.imageData = dataBlob->getMatrix(); +} + +template<typename BlockType> +void MachineLearningProcessor<BlockType>::spawnTask(const MachineLearningMethodBase* method, QString state) +{ + _spawnedTask = nullptr; + if (auto spawner = method->createTaskSpawner()) { - QString taskName = QString{"%1 %2 (%3)"}.arg(method->getMethodName()).arg(getSpawnTypeName(type)).arg(this->_block->getFormattedName()); - std::shared_ptr<Task> task; + QString taskName = QString{"%1 %2 (%3)"}.arg(method->getMethodName()).arg(getSpawnTypeName(_spawnType)).arg(this->_block->getFormattedName()); - switch (type) + switch (_spawnType) { case SpawnType::Training: - task = spawner->spawnTrainingTask(state, taskName); + _spawnedTask = spawner->spawnTrainingTask(state, taskName); break; case SpawnType::Inference: - task = spawner->spawnInferenceTask(state, taskName); + _spawnedTask = spawner->spawnInferenceTask(state, taskName); break; } - return task; + if (!_spawnedTask) + this->throwProcessorException("Unable to spawn the machine learning task"); } else this->throwProcessorException("Unable to create a task spawner"); - - return nullptr; } template<typename BlockType> diff --git a/Grinder/ml/processors/TrainingProcessor.cpp b/Grinder/ml/processors/TrainingProcessor.cpp index cff78977efa5cf42c8787c6a0dc6103db0748570..bef117ffe58848954ea23e590e79b4de9c46294a 100644 --- a/Grinder/ml/processors/TrainingProcessor.cpp +++ b/Grinder/ml/processors/TrainingProcessor.cpp @@ -7,23 +7,15 @@ #include "TrainingProcessor.h" #include "ml/MachineLearningTaskSpawnerBase.h" -TrainingProcessor::TrainingProcessor(const Block* block) : MachineLearningProcessor(block) +TrainingProcessor::TrainingProcessor(const Block* block) : MachineLearningProcessor(block, SpawnType::Training, true) { } -void TrainingProcessor::execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) +void TrainingProcessor::fillTaskData(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, MachineLearningTaskData& taskData) const { - // Training is only executed in batch mode - if (ctx.hasExecutionFlag(Engine::ExecutionFlag::Batch)) - { - // Spawn the training task when the first image is active - if (ctx.isFirstImage()) - { - // TODO: Gather images, labels etc. before spawning - spawnTask(SpawnType::Training, method, state); - } - } - else - throwProcessorException("Training is only possible in batch mode; bypass the training block to avoid this warning"); + MachineLearningProcessor::fillTaskData(ctx, method, taskData); + + if (auto dataBlob = portData(ctx, _block->tagsBitmapPort())) + taskData.imageTagsData = dataBlob->getMatrix(); } diff --git a/Grinder/ml/processors/TrainingProcessor.h b/Grinder/ml/processors/TrainingProcessor.h index 5dde33bf14840cc412e34b3a8c09190a0098c7fb..9d896e177d56286e38097e4abef42b17e52ac7ad 100644 --- a/Grinder/ml/processors/TrainingProcessor.h +++ b/Grinder/ml/processors/TrainingProcessor.h @@ -17,7 +17,7 @@ namespace grndr TrainingProcessor(const Block* block); protected: - virtual void execute(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, QString state) override; + virtual void fillTaskData(EngineExecutionContext& ctx, const MachineLearningMethodBase* method, MachineLearningTaskData& taskData) const override; }; } diff --git a/Grinder/ml/tasks/MachineLearningTask.cpp b/Grinder/ml/tasks/MachineLearningTask.cpp new file mode 100644 index 0000000000000000000000000000000000000000..413adc18c666e2360f5f68ea56d56943fe4753f2 --- /dev/null +++ b/Grinder/ml/tasks/MachineLearningTask.cpp @@ -0,0 +1,7 @@ +/****************************************************************************** + * File: MachineLearningTask.cpp + * Date: 27.8.2019 + *****************************************************************************/ + +#include "Grinder.h" +#include "MachineLearningTask.h" diff --git a/Grinder/ml/tasks/MachineLearningTask.h b/Grinder/ml/tasks/MachineLearningTask.h new file mode 100644 index 0000000000000000000000000000000000000000..c3ea9b65825e7fc3943b4277f70bc56b18e1d582 --- /dev/null +++ b/Grinder/ml/tasks/MachineLearningTask.h @@ -0,0 +1,30 @@ +/****************************************************************************** + * File: MachineLearningTask.h + * Date: 27.8.2019 + *****************************************************************************/ + +#ifndef MACHINELEARNINGTASK_H +#define MACHINELEARNINGTASK_H + +#include <opencv2/core.hpp> + +#include "pipeline/tasks/PipelineTask.h" + +namespace grndr +{ + class ImageTags; + + struct MachineLearningTaskData + { + cv::Mat imageData; + cv::Mat imageTagsData; + }; + + class MachineLearningTask : public PipelineTask<MachineLearningTaskData> + { + public: + using PipelineTask<MachineLearningTaskData>::PipelineTask; + }; +} + +#endif diff --git a/Grinder/pipeline/tasks/PipelineTask.h b/Grinder/pipeline/tasks/PipelineTask.h new file mode 100644 index 0000000000000000000000000000000000000000..4e9980f1b24d81f2f7267d3c896ee9f4bf9e029d --- /dev/null +++ b/Grinder/pipeline/tasks/PipelineTask.h @@ -0,0 +1,31 @@ +/****************************************************************************** + * File: PipelineTask.h + * Date: 27.8.2019 + *****************************************************************************/ + +#ifndef PIPELINETASK_H +#define PIPELINETASK_H + +#include "task/Task.h" + +namespace grndr +{ + class EngineExecutionContext; + + template<typename DataType> + class PipelineTask : public Task + { + public: + using data_type = DataType; + + public: + using Task::Task; + + public: + virtual void processEngineStart(EngineExecutionContext& ctx, const data_type& data) { Q_UNUSED(ctx); Q_UNUSED(data); } + virtual void processEnginePass(EngineExecutionContext& ctx, const data_type& data) { Q_UNUSED(ctx); Q_UNUSED(data); } + virtual void processEngineEnd(EngineExecutionContext& ctx, const data_type& data) { Q_UNUSED(ctx); Q_UNUSED(data); } + }; +} + +#endif diff --git a/Grinder/project/exporters/HDF5File.cpp b/Grinder/project/exporters/HDF5Export.cpp similarity index 84% rename from Grinder/project/exporters/HDF5File.cpp rename to Grinder/project/exporters/HDF5Export.cpp index 7863a878048b64f89fe3c9af6f4bfc02f1a65970..c712faa55fde84daedead56ab21c86456fbb44e7 100644 --- a/Grinder/project/exporters/HDF5File.cpp +++ b/Grinder/project/exporters/HDF5Export.cpp @@ -1,10 +1,10 @@ /****************************************************************************** - * File: HDF5File.cpp + * File: HDF5Export.cpp * Date: 26.8.2019 *****************************************************************************/ #include "Grinder.h" -#include "HDF5File.h" +#include "HDF5Export.h" #include "project/ProjectExceptions.h" #include <opencv2/imgproc.hpp> @@ -12,19 +12,23 @@ #define HDF5_DATASET_DATA "data" #define HDF5_DATASET_TAGS "label" -HDF5File::HDF5File(QString filename, bool truncate) +HDF5Export::HDF5Export(QString filename, bool truncate) : + _filename{filename} { H5::Exception::dontPrint(); try { + QFileInfo fi{filename}; + fi.dir().mkpath(fi.dir().path()); + // Create/open the H5 file - _h5File = H5::H5File{filename.toLatin1(), truncate ? H5F_ACC_TRUNC : 0}; + _h5File = H5::H5File{_filename.toLatin1(), truncate ? H5F_ACC_TRUNC : 0}; } catch (H5::Exception& e) { throwH5Exception(e); } } -void HDF5File::initExport(QSize imageSize, unsigned int imageCount, unsigned int tagsCount, ExportFlags flags) +void HDF5Export::initExport(QSize imageSize, unsigned int imageCount, unsigned int tagsCount, ExportFlags flags) { if (imageSize.isNull()) throw ExportException{nullptr, _EXCPT("Image size may not be 0x0")}; @@ -59,7 +63,7 @@ void HDF5File::initExport(QSize imageSize, unsigned int imageCount, unsigned int } } -void HDF5File::exportImageEx(const cv::Mat& image, const std::vector<cv::Mat>& tagMatrices) const +void HDF5Export::exportImageEx(const cv::Mat& image, const std::vector<cv::Mat>& tagMatrices) const { // Verify parameters if (static_cast<hsize_t>(image.rows) != _imageSize.second || static_cast<hsize_t>(image.cols) != _imageSize.first) @@ -113,7 +117,7 @@ void HDF5File::exportImageEx(const cv::Mat& image, const std::vector<cv::Mat>& t _currentImage += 1; } -void HDF5File::exportImageTags(const std::vector<cv::Mat>& tagMatrices) const +void HDF5Export::exportImageTags(const std::vector<cv::Mat>& tagMatrices) const { if (tagMatrices.size() > _tagsCount) throw ExportException{nullptr, _EXCPT(QString{"Tried to export more than %1 image tag(s)"}.arg(_tagsCount))}; @@ -140,14 +144,14 @@ void HDF5File::exportImageTags(const std::vector<cv::Mat>& tagMatrices) const } } -H5::DataSpace HDF5File::createDataSpace(unsigned int channels) const +H5::DataSpace HDF5Export::createDataSpace(unsigned int channels) const { // Our dataspace is 4D: Index, number of channels, Y and X hsize_t dataSpaceDims[] = {_imageCount, channels, _imageSize.second, _imageSize.first}; return H5::DataSpace{4, dataSpaceDims}; } -H5::DataSet HDF5File::createDataSet(const H5::DataSpace& dataSpace, QString name, H5::PredType predType) const +H5::DataSet HDF5Export::createDataSet(const H5::DataSpace& dataSpace, QString name, H5::PredType predType) const { H5::DSetCreatPropList propList{H5::DSetCreatPropList::DEFAULT}; @@ -161,7 +165,7 @@ H5::DataSet HDF5File::createDataSet(const H5::DataSpace& dataSpace, QString name return _h5File.createDataSet(name.toLatin1(), predType, dataSpace, propList); } -void HDF5File::throwH5Exception(H5::Exception& e) const +void HDF5Export::throwH5Exception(H5::Exception& e) const { throw ExportException{nullptr, _EXCPT(QString{"HDF5 error: %1"}.arg(e.getDetailMsg().data()))}; } diff --git a/Grinder/project/exporters/HDF5File.h b/Grinder/project/exporters/HDF5Export.h similarity index 79% rename from Grinder/project/exporters/HDF5File.h rename to Grinder/project/exporters/HDF5Export.h index 95ea8d3ff5bb94d0d5ee92d6808ebdaca9597249..4d6777ac4749cd7d0d5101305983745d4a35e855 100644 --- a/Grinder/project/exporters/HDF5File.h +++ b/Grinder/project/exporters/HDF5Export.h @@ -1,19 +1,20 @@ /****************************************************************************** - * File: HDF5File.h + * File: HDF5Export.h * Date: 26.8.2019 *****************************************************************************/ -#ifndef HDF5FILE_H -#define HDF5FILE_H +#ifndef HDF5EXPORT_H +#define HDF5EXPORT_H #include <H5Cpp.h> #include <opencv2/core.hpp> #include <QSize> +#include <QString> namespace grndr { - class HDF5File + class HDF5Export { public: enum class ExportFlag : unsigned int @@ -29,7 +30,7 @@ namespace grndr Q_DECLARE_FLAGS(ExportFlags, ExportFlag) public: - HDF5File(QString filename, bool truncate = true); + HDF5Export(QString filename, bool truncate = true); public: void initExport(QSize imageSize, unsigned int imageCount, unsigned int tagsCount, ExportFlags flags = ExportFlag::ExportTags); @@ -37,6 +38,9 @@ namespace grndr void exportImage(const cv::Mat& image) const { exportImageEx(image, {}); } void exportImageEx(const cv::Mat& image, const std::vector<cv::Mat>& tagMatrices) const; + public: + QString getFilename() const { return _filename; } + private: void exportImageTags(const std::vector<cv::Mat>& tagMatrices) const; @@ -48,6 +52,8 @@ namespace grndr void throwH5Exception(H5::Exception& e) const; private: + QString _filename{""}; + std::pair<hsize_t, hsize_t> _imageSize; unsigned int _imageCount{0}; unsigned int _tagsCount{0}; @@ -67,6 +73,6 @@ namespace grndr }; } -Q_DECLARE_OPERATORS_FOR_FLAGS(grndr::HDF5File::ExportFlags) +Q_DECLARE_OPERATORS_FOR_FLAGS(grndr::HDF5Export::ExportFlags) #endif diff --git a/Grinder/project/exporters/HDF5Exporter.cpp b/Grinder/project/exporters/HDF5Exporter.cpp index c697e189c354fc27627cd9952324b27531abe008..bcfe29afd9ebfcb924cc0e2f32c815e619bb1487 100644 --- a/Grinder/project/exporters/HDF5Exporter.cpp +++ b/Grinder/project/exporters/HDF5Exporter.cpp @@ -11,7 +11,7 @@ #include "image/ImageTags.h" #include "ui/dlg/HDF5ExportDialog.h" -HDF5Exporter::HDF5Exporter(Label* label, const Block* canvasBlock, const ImageReferenceSelection& imageReferences, HDF5File::ExportFlags exportFlags) : ProjectExporter("HDF5 Exporter", "*.h5;*.hdf5"), +HDF5Exporter::HDF5Exporter(Label* label, const Block* canvasBlock, const ImageReferenceSelection& imageReferences, HDF5Export::ExportFlags exportFlags) : ProjectExporter("HDF5 Exporter", "*.h5;*.hdf5"), _label{label}, _canvasBlock{canvasBlock}, _imageReferences{imageReferences}, _exportFlags{exportFlags} { @@ -30,13 +30,13 @@ bool HDF5Exporter::invokeUi(const Project* project, QWidget* parent) _canvasBlock = dlg.getCanvasBlock(); _imageReferences = dlg.getImageReferences(); - _exportFlags = HDF5File::ExportFlag::None; + _exportFlags = HDF5Export::ExportFlag::None; if (dlg.exportAsGrayscale()) - _exportFlags |= HDF5File::ExportFlag::ExportAsGrayscale; + _exportFlags |= HDF5Export::ExportFlag::ExportAsGrayscale; if (dlg.exportImageTags()) - _exportFlags |= HDF5File::ExportFlag::ExportTags; + _exportFlags |= HDF5Export::ExportFlag::ExportTags; return true; } @@ -54,7 +54,7 @@ void HDF5Exporter::exportProject(const Project* project, QString fileName) verifyImageReferences(project); // Prepare the H5 file - HDF5File h5File{fileName}; + HDF5Export h5File{fileName}; h5File.initExport(getImageSize(), _imageReferences.size(), getImageTagsCount(), _exportFlags); // Export images; tags will also be exported if necessary @@ -99,7 +99,7 @@ unsigned int HDF5Exporter::getImageTagsCount() const return tagsCount; } -void HDF5Exporter::exportImageReferences(const Project* project, const HDF5File& h5File) const +void HDF5Exporter::exportImageReferences(const Project* project, const HDF5Export& h5File) const { LongOperation opExportImages{"Exporting images", static_cast<unsigned int>(_imageReferences.size())}; @@ -107,7 +107,7 @@ void HDF5Exporter::exportImageReferences(const Project* project, const HDF5File& exportImageReference(project, imgRef, h5File); } -void HDF5Exporter::exportImageReference(const Project* project, const ImageReference* imgRef, const HDF5File& h5File) const +void HDF5Exporter::exportImageReference(const Project* project, const ImageReference* imgRef, const HDF5Export& h5File) const { LongOperationStep opExportImage{imgRef->getImageFilePath()}; cv::Mat imgData; @@ -131,13 +131,13 @@ void HDF5Exporter::exportImageReference(const Project* project, const ImageRefer // Generate the tag matrices std::vector<cv::Mat> tagMatrices; - if (_exportFlags.testFlag(HDF5File::ExportFlag::ExportTags)) + if (_exportFlags.testFlag(HDF5Export::ExportFlag::ExportTags)) { auto imageTagsBitmap = grinder()->engineController().generateImageTagsBitmap(_label, _canvasBlock, imgRef, false); if (imageTagsBitmap.isValid()) { - if (_exportFlags.testFlag(HDF5File::ExportFlag::MergeTags)) + if (_exportFlags.testFlag(HDF5Export::ExportFlag::MergeTags)) tagMatrices = exportImageTags_Merged(imageTagsBitmap, getImageTagsCount()); else tagMatrices = exportImageTags_Individually(imageTagsBitmap, getImageTagsCount()); diff --git a/Grinder/project/exporters/HDF5Exporter.h b/Grinder/project/exporters/HDF5Exporter.h index db42237dfb1f50bdd528181e233a067832ccd5c7..09e51642fa0ff914430feb25250d6e40365af1ec 100644 --- a/Grinder/project/exporters/HDF5Exporter.h +++ b/Grinder/project/exporters/HDF5Exporter.h @@ -8,7 +8,7 @@ #include <opencv2/core.hpp> -#include "HDF5File.h" +#include "HDF5Export.h" #include "project/ProjectExporter.h" #include "project/ImageReferenceSelection.h" @@ -22,7 +22,7 @@ namespace grndr class HDF5Exporter : public ProjectExporter { public: - HDF5Exporter(Label* label = nullptr, const Block* canvasBlock = nullptr, const ImageReferenceSelection& imageReferences = {}, HDF5File::ExportFlags exportFlags = HDF5File::ExportFlag::ExportTags); + HDF5Exporter(Label* label = nullptr, const Block* canvasBlock = nullptr, const ImageReferenceSelection& imageReferences = {}, HDF5Export::ExportFlags exportFlags = HDF5Export::ExportFlag::ExportTags); public: virtual bool invokeUi(const Project* project, QWidget* parent) override; @@ -35,8 +35,8 @@ namespace grndr unsigned int getImageTagsCount() const; private: - void exportImageReferences(const Project* project, const HDF5File& h5File) const; - void exportImageReference(const Project* project, const ImageReference* imgRef, const HDF5File& h5File) const; + void exportImageReferences(const Project* project, const HDF5Export& h5File) const; + void exportImageReference(const Project* project, const ImageReference* imgRef, const HDF5Export& h5File) const; void exportImageTags(const Project* project, H5::H5File& h5File) const; void exportImageTags(const Project* project, const ImageReference* imgRef, unsigned int tagCount, unsigned int index, const H5::DataSpace& dataSpace, const H5::DataSet& dataSet) const; @@ -48,7 +48,7 @@ namespace grndr const Block* _canvasBlock{nullptr}; ImageReferenceSelection _imageReferences; - HDF5File::ExportFlags _exportFlags{HDF5File::ExportFlag::None}; + HDF5Export::ExportFlags _exportFlags{HDF5Export::ExportFlag::None}; }; } diff --git a/Grinder/task/Task.cpp b/Grinder/task/Task.cpp index ddd74a9f9321abd8902eaaa0475dac570c736da0..1fc32d86b245905a610242601934f96e0139c836 100644 --- a/Grinder/task/Task.cpp +++ b/Grinder/task/Task.cpp @@ -1,180 +1,180 @@ -/****************************************************************************** - * File: Task.cpp - * Date: 31.10.2018 - *****************************************************************************/ - -#include "Grinder.h" -#include "Task.h" - -const char* Task::Serialization_Value_Type = "Type"; -const char* Task::Serialization_Value_Name = "Name"; - -Task::Task(TaskPool* taskPool, TaskType type, Capabilities caps, QString name) : - _taskPool{taskPool}, _type{type}, _capabilities{caps}, _name{name} -{ - if (!taskPool) - throw std::invalid_argument{_EXCPT("taskPool may not be null")}; -} - -void Task::startTask() -{ - if (!isRunning()) - { - setResult(Result::None); - _isPaused = false; - _isStopped = false; - _message = ""; - _messageLog.clear(); - - try { - reportStatus(QString{"Starting task '%1'..."}.arg(_name), false, true); - - verifyTask(); - execute(); - setStatus(Status::Running); - - emit taskStarted(); - emit taskUpdated(); - } catch (std::exception& e) { - // Executing the task failed - reportError("execute", GetExceptionMessage(e.what())); - setResult(Result::Failed); - throw; - } - } -} - -void Task::pauseTask(bool setPause) -{ - if (isRunning() && _capabilities.testFlag(Capability::CanBePaused)) - { - if (setPause != _isPaused) - { - try { - reportStatus(QString{"%1 task '%2'"}.arg(setPause ? "Pausing" : "Unpausing").arg(_name), true, true); - - pause(setPause); - _isPaused = setPause; - - emit taskPaused(setPause); - emit taskUpdated(); - } catch (std::exception& e) { - // Pausing/Unpausing the task failed - reportError(setPause ? "pause" : "unpause", GetExceptionMessage(e.what())); - throw; - } - } - } -} - -void Task::refreshTask() -{ - if (isRunning() && _capabilities.testFlag(Capability::CanBeRefreshed)) - { - try { - reportStatus(QString{"Refreshing task '%1'..."}.arg(_name), true); - - refresh(); - - emit taskRefreshed(); - emit taskUpdated(); - } catch (std::exception& e) { - // Refreshing the task failed - reportError("refresh", GetExceptionMessage(e.what())); - throw; - } - } -} - -void Task::stopTask() -{ - if (isRunning() && _capabilities.testFlag(Capability::CanBeStopped)) - { - try { - reportStatus(QString{"Stopping task '%1'..."}.arg(_name), true); - - stop(); - _isStopped = true; - - emit taskStopped(); - emit taskUpdated(); - } catch (std::exception& e) { - // Stopping the task failed - reportError("stop", GetExceptionMessage(e.what())); - throw; - } - } -} - -void Task::finishTask(bool succeeded) -{ - if (isRunning()) - { - // Reset the task, keeping only the failed status - setStatus(Status::Pending); - - if (!_isStopped) - { - setResult(succeeded ? Result::Succeeded : Result::Failed); - reportStatus(QString{"Task '%1' %2"}.arg(_name).arg(succeeded ? "succeeded" : "failed"), true); - } - else - { - setResult(Result::None); - reportStatus(QString{"Task '%1' stopped"}.arg(_name), true); - } - - _isPaused = false; - - finish(succeeded); - - emit taskFinished(); - emit taskUpdated(); - } -} - -void Task::updateTask() -{ - if (isRunning()) - { - try { - update(); - } catch (std::exception& e) { - // Updating the task failed - reportError("update", GetExceptionMessage(e.what())); - finishTask(false); - throw; - } - } -} - -void Task::serialize(SerializationContext& ctx) const -{ - // Serialize values - ctx.settings()(Serialization_Value_Type) = _type; - ctx.settings()(Serialization_Value_Name) = _name; -} - -void Task::deserialize(DeserializationContext& ctx) -{ - // Deserialize values - _name = ctx.settings()(Serialization_Value_Name).toString(); -} - -void Task::reportStatus(QString status, bool preLine, bool postLine) -{ - auto addLine = [this]() { addLogMessage("---------------------------------------------------------------------", false); }; - - if (preLine) - addLine(); - - addLogMessage(status); - - if (postLine) - addLine(); -} - -void Task::reportError(QString action, QString reason) -{ - addLogMessage(QString{"Failed to %1 task '%2': %3"}.arg(action).arg(_name).arg(reason)); -} +/****************************************************************************** + * File: Task.cpp + * Date: 31.10.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "Task.h" + +const char* Task::Serialization_Value_Type = "Type"; +const char* Task::Serialization_Value_Name = "Name"; + +Task::Task(TaskPool* taskPool, TaskType type, Capabilities caps, QString name) : + _taskPool{taskPool}, _type{type}, _capabilities{caps}, _name{name} +{ + if (!taskPool) + throw std::invalid_argument{_EXCPT("taskPool may not be null")}; +} + +void Task::startTask() +{ + if (!isRunning()) + { + setResult(Result::None); + _isPaused = false; + _isStopped = false; + _message = ""; + _messageLog.clear(); + + try { + reportStatus(QString{"Starting task '%1'..."}.arg(_name), false, true); + + verifyTask(); + execute(); + setStatus(Status::Running); + + emit taskStarted(); + emit taskUpdated(); + } catch (std::exception& e) { + // Executing the task failed + reportError("execute", GetExceptionMessage(e.what())); + setResult(Result::Failed); + throw; + } + } +} + +void Task::pauseTask(bool setPause) +{ + if (isRunning() && _capabilities.testFlag(Capability::CanBePaused)) + { + if (setPause != _isPaused) + { + try { + reportStatus(QString{"%1 task '%2'"}.arg(setPause ? "Pausing" : "Unpausing").arg(_name), true, true); + + pause(setPause); + _isPaused = setPause; + + emit taskPaused(setPause); + emit taskUpdated(); + } catch (std::exception& e) { + // Pausing/Unpausing the task failed + reportError(setPause ? "pause" : "unpause", GetExceptionMessage(e.what())); + throw; + } + } + } +} + +void Task::refreshTask() +{ + if (isRunning() && _capabilities.testFlag(Capability::CanBeRefreshed)) + { + try { + reportStatus(QString{"Refreshing task '%1'..."}.arg(_name), true); + + refresh(); + + emit taskRefreshed(); + emit taskUpdated(); + } catch (std::exception& e) { + // Refreshing the task failed + reportError("refresh", GetExceptionMessage(e.what())); + throw; + } + } +} + +void Task::stopTask() +{ + if (isRunning() && _capabilities.testFlag(Capability::CanBeStopped)) + { + try { + reportStatus(QString{"Stopping task '%1'..."}.arg(_name), true); + + stop(); + _isStopped = true; + + emit taskStopped(); + emit taskUpdated(); + } catch (std::exception& e) { + // Stopping the task failed + reportError("stop", GetExceptionMessage(e.what())); + throw; + } + } +} + +void Task::finishTask(bool succeeded) +{ + if (isRunning()) + { + // Reset the task, keeping only the failed status + setStatus(Status::Pending); + + if (!_isStopped) + { + setResult(succeeded ? Result::Succeeded : Result::Failed); + reportStatus(QString{"Task '%1' %2"}.arg(_name).arg(succeeded ? "succeeded" : "failed"), true); + } + else + { + setResult(Result::None); + reportStatus(QString{"Task '%1' stopped"}.arg(_name), true); + } + + _isPaused = false; + + finish(succeeded); + + emit taskFinished(); + emit taskUpdated(); + } +} + +void Task::updateTask() +{ + if (isRunning()) + { + try { + update(); + } catch (std::exception& e) { + // Updating the task failed + reportError("update", GetExceptionMessage(e.what())); + finishTask(false); + throw; + } + } +} + +void Task::serialize(SerializationContext& ctx) const +{ + // Serialize values + ctx.settings()(Serialization_Value_Type) = _type; + ctx.settings()(Serialization_Value_Name) = _name; +} + +void Task::deserialize(DeserializationContext& ctx) +{ + // Deserialize values + _name = ctx.settings()(Serialization_Value_Name).toString(); +} + +void Task::reportStatus(QString status, bool preLine, bool postLine) +{ + auto addLine = [this]() { addLogMessage("---------------------------------------------------------------------", false); }; + + if (preLine) + addLine(); + + addLogMessage(status); + + if (postLine) + addLine(); +} + +void Task::reportError(QString action, QString reason) +{ + addLogMessage(QString{"Failed to %1 task '%2': %3"}.arg(action).arg(_name).arg(reason)); +} diff --git a/Grinder/task/Task.h b/Grinder/task/Task.h index cf462fe43b9d54bebde25b3849984d906fb7701a..102eba0fe61dc2dd452dc5bd7517a642ded942b3 100644 --- a/Grinder/task/Task.h +++ b/Grinder/task/Task.h @@ -1,143 +1,143 @@ -/****************************************************************************** - * File: Task.h - * Date: 31.10.2018 - *****************************************************************************/ - -#ifndef TASK_H -#define TASK_H - -#include "TaskType.h" -#include "common/serialization/SerializationContext.h" -#include "common/serialization/DeserializationContext.h" - -namespace grndr -{ - class TaskPool; - class ConfigureTaskWidgetBase; - - class Task : public QObject - { - Q_OBJECT - - public: - static const char* Serialization_Value_Type; - static const char* Serialization_Value_Name; - - public: - enum class Status - { - Pending, - Running, - }; - - enum class Result - { - None, - Succeeded, - Failed, - }; - - enum class Capability : unsigned int - { - None = 0x0000, - - CanBePaused = 0x0001, - CanBeStopped = 0x0002, - CanBeRefreshed = 0x0004, - - HasProgress = 0x0008, - - All = 0xFFFF, - }; - - Q_DECLARE_FLAGS(Capabilities, Capability) - - public: - Task(TaskPool* taskPool, TaskType type, Capabilities caps, QString name = ""); - - public: - virtual void initTask() { } - - public: - void startTask(); - void pauseTask(bool setPause = true); - void refreshTask(); - void stopTask(); - void finishTask(bool succeeded); - - void updateTask(); - - public: - virtual ConfigureTaskWidgetBase* createEditor(bool newTask, QWidget* parent = nullptr) { Q_UNUSED(newTask); Q_UNUSED(parent); return nullptr; } - - public: - TaskPool* taskPool() { return _taskPool; } - const TaskPool* taskPool() const { return _taskPool; } - - TaskType getType() const { return _type; } - Capabilities getCapabilities() const { return _capabilities; } - QString getName() const { return _name; } - void setName(QString name) { if (_name != name) { _name = name; emit taskUpdated(); } } - - Status getStatus() const { return _status; } - void setStatus(Status status) { if (_status != status) { _status = status; emit taskUpdated(); } } - Result getResult() const { return _result; } - void setResult(Result result) { if (_result != result) { _result = result; emit taskUpdated(); } } - bool isRunning() const { return _status == Status::Running; } - bool isPaused() const { return _isPaused; } - float getProgress() const { return _progress; } - void setProgress(float progress) { if (_progress != progress) { _progress = progress; emit taskUpdated(); } } - QString getMessage() const { return _message; } - void setMessage(QString message) { if (_message != message) { _message = message; emit taskUpdated(); } } - QStringList getMessageLog() const { return _messageLog; } - void addLogMessage(QString message, bool setMsg = true) { _messageLog << message; if (setMsg) setMessage(message); else emit taskUpdated(); } - void clearMessageLog() { _messageLog.clear(); emit taskUpdated(); } - - public: - virtual void serialize(SerializationContext& ctx) const; - virtual void deserialize(DeserializationContext& ctx); - - signals: - void taskStarted(); - void taskPaused(bool); - void taskRefreshed(); - void taskStopped(); - void taskFinished(); - - void taskUpdated(); - - protected: - virtual void verifyTask() const { }; - - virtual void execute() { setProgress(0.0f); } - virtual void pause(bool setPause) { Q_UNUSED(setPause); } - virtual void refresh() { } - virtual void stop() { setProgress(0.0f); } - - virtual void update() { } - virtual void finish(bool succeeded) { Q_UNUSED(succeeded); } - - protected: - void reportStatus(QString status, bool preLine = false, bool postLine = false); - void reportError(QString action, QString reason); - - protected: - TaskPool* _taskPool{nullptr}; - - TaskType _type{TaskType::Undefined}; - Capabilities _capabilities{Capability::None}; - QString _name; - - Status _status{Status::Pending}; - Result _result{Result::None}; - bool _isPaused{false}; - bool _isStopped{false}; - float _progress{0.0f}; - QString _message; - QStringList _messageLog; - }; -} - -Q_DECLARE_OPERATORS_FOR_FLAGS(grndr::Task::Capabilities) - -#endif +/****************************************************************************** + * File: Task.h + * Date: 31.10.2018 + *****************************************************************************/ + +#ifndef TASK_H +#define TASK_H + +#include "TaskType.h" +#include "common/serialization/SerializationContext.h" +#include "common/serialization/DeserializationContext.h" + +namespace grndr +{ + class TaskPool; + class ConfigureTaskWidgetBase; + + class Task : public QObject + { + Q_OBJECT + + public: + static const char* Serialization_Value_Type; + static const char* Serialization_Value_Name; + + public: + enum class Status + { + Pending, + Running, + }; + + enum class Result + { + None, + Succeeded, + Failed, + }; + + enum class Capability : unsigned int + { + None = 0x0000, + + CanBePaused = 0x0001, + CanBeStopped = 0x0002, + CanBeRefreshed = 0x0004, + + HasProgress = 0x0008, + + All = 0xFFFF, + }; + + Q_DECLARE_FLAGS(Capabilities, Capability) + + public: + Task(TaskPool* taskPool, TaskType type, Capabilities caps, QString name = ""); + + public: + virtual void initTask() { } + + public: + void startTask(); + void pauseTask(bool setPause = true); + void refreshTask(); + void stopTask(); + void finishTask(bool succeeded); + + void updateTask(); + + public: + virtual ConfigureTaskWidgetBase* createEditor(bool newTask, QWidget* parent = nullptr) { Q_UNUSED(newTask); Q_UNUSED(parent); return nullptr; } + + public: + TaskPool* taskPool() { return _taskPool; } + const TaskPool* taskPool() const { return _taskPool; } + + TaskType getType() const { return _type; } + Capabilities getCapabilities() const { return _capabilities; } + QString getName() const { return _name; } + void setName(QString name) { if (_name != name) { _name = name; emit taskUpdated(); } } + + Status getStatus() const { return _status; } + void setStatus(Status status) { if (_status != status) { _status = status; emit taskUpdated(); } } + Result getResult() const { return _result; } + void setResult(Result result) { if (_result != result) { _result = result; emit taskUpdated(); } } + bool isRunning() const { return _status == Status::Running; } + bool isPaused() const { return _isPaused; } + float getProgress() const { return _progress; } + void setProgress(float progress) { if (_progress != progress) { _progress = progress; emit taskUpdated(); } } + QString getMessage() const { return _message; } + void setMessage(QString message) { if (_message != message) { _message = message; emit taskUpdated(); } } + QStringList getMessageLog() const { return _messageLog; } + void addLogMessage(QString message, bool setMsg = true) { _messageLog << message; if (setMsg) setMessage(message); else emit taskUpdated(); } + void clearMessageLog() { _messageLog.clear(); emit taskUpdated(); } + + public: + virtual void serialize(SerializationContext& ctx) const; + virtual void deserialize(DeserializationContext& ctx); + + signals: + void taskStarted(); + void taskPaused(bool); + void taskRefreshed(); + void taskStopped(); + void taskFinished(); + + void taskUpdated(); + + protected: + virtual void verifyTask() const { }; + + virtual void execute() { setProgress(0.0f); } + virtual void pause(bool setPause) { Q_UNUSED(setPause); } + virtual void refresh() { } + virtual void stop() { setProgress(0.0f); } + + virtual void update() { } + virtual void finish(bool succeeded) { Q_UNUSED(succeeded); } + + protected: + void reportStatus(QString status, bool preLine = false, bool postLine = false); + void reportError(QString action, QString reason); + + protected: + TaskPool* _taskPool{nullptr}; + + TaskType _type{TaskType::Undefined}; + Capabilities _capabilities{Capability::None}; + QString _name; + + Status _status{Status::Pending}; + Result _result{Result::None}; + bool _isPaused{false}; + bool _isStopped{false}; + float _progress{0.0f}; + QString _message; + QStringList _messageLog; + }; +} + +Q_DECLARE_OPERATORS_FOR_FLAGS(grndr::Task::Capabilities) + +#endif diff --git a/Grinder/task/TaskCatalog.h b/Grinder/task/TaskCatalog.h index 47327f55e19e7df7fa70dcba1126f2a207fc586a..3e5de148d5c7d34e7d239ec94c3d754a46bbe3d7 100644 --- a/Grinder/task/TaskCatalog.h +++ b/Grinder/task/TaskCatalog.h @@ -1,45 +1,46 @@ -/****************************************************************************** - * File: TaskCatalog.h - * Date: 31.10.2018 - *****************************************************************************/ - -#ifndef TASKCATALOG_H -#define TASKCATALOG_H - -#include <map> -#include <set> -#include <functional> -#include <memory> - -#include "TaskType.h" - -namespace grndr -{ - class TaskPool; - class Task; - - class TaskCatalog - { - private: - using task_creator_type = std::function<std::unique_ptr<Task>(TaskPool*, QString)>; - - private: - TaskCatalog() { } - - public: - static std::unique_ptr<Task> createTask(TaskPool* taskPool, TaskType type, QString name = ""); - - static void registerStandardTasks(); - - public: - static std::set<TaskType> getTypes(); - - private: - static void registerTaskType(TaskType type, task_creator_type creator); - - private: - static std::map<TaskType, task_creator_type> s_creators; - }; -} - -#endif +/****************************************************************************** + * File: TaskCatalog.h + * Date: 31.10.2018 + *****************************************************************************/ + +#ifndef TASKCATALOG_H +#define TASKCATALOG_H + +#include <map> +#include <set> +#include <functional> +#include <memory> + +#include "TaskType.h" +#include "Task.h" + +namespace grndr +{ + class TaskPool; + class Task; + + class TaskCatalog + { + private: + using task_creator_type = std::function<std::unique_ptr<Task>(TaskPool*, QString)>; + + private: + TaskCatalog() { } + + public: + static std::unique_ptr<Task> createTask(TaskPool* taskPool, TaskType type, QString name = ""); + + static void registerStandardTasks(); + + public: + static std::set<TaskType> getTypes(); + + private: + static void registerTaskType(TaskType type, task_creator_type creator); + + private: + static std::map<TaskType, task_creator_type> s_creators; + }; +} + +#endif diff --git a/Grinder/task/TaskPool.cpp b/Grinder/task/TaskPool.cpp index 6fb7f6c3881cc354558ed9b8710a36033c9a9d23..8e789e18a8ba31718163be036f533d04c404eadd 100644 --- a/Grinder/task/TaskPool.cpp +++ b/Grinder/task/TaskPool.cpp @@ -1,102 +1,102 @@ -/****************************************************************************** - * File: TaskPool.cpp - * Date: 31.10.2018 - *****************************************************************************/ - -#include "Grinder.h" -#include "TaskPool.h" -#include "TaskCatalog.h" -#include "TaskExceptions.h" - -const char* TaskPool::Serialization_Group = "TaskPool"; - -TaskPool::TaskPool(const Project* project) : - _project{project}, _tasks{this} -{ - if (!project) - throw std::invalid_argument{_EXCPT("project may not be null")}; -} - -std::shared_ptr<Task> TaskPool::createTask(TaskType type, QString name, bool addToPool) -{ - if (type == TaskType::Undefined) - throw std::invalid_argument(_EXCPT("type may not be TaskType::Undefined")); - - // Create new task using the task factory; cast the unique ptr to a shared one as well - std::shared_ptr<Task> task = TaskCatalog::createTask(this, type, name); - - try { // Propagate initialization errors to the caller - task->initTask(); - } - catch (...) { - throw; - } - - if (addToPool) - addTask(task); - - return task; -} - -void TaskPool::addTask(std::shared_ptr<Task>& task) -{ - _tasks.push_back(task); - emit taskCreated(task); -} - -void TaskPool::removeTask(const Task* task) -{ - if (task) - { - auto it = _tasks.find(task); - - if (it != _tasks.cend()) - { - // Keep a copy of the shared_ptr holding the task to increase its use count; - // otherwise, the task will be deleted before it has been removed from the vector, potentially causing a crash - auto task = *it; - - emit taskRemoved(*it); - _tasks.erase(it); - } - else - throw TaskPoolException{this, _EXCPT("Tried to remove a task not belonging to the task pool")}; - } -} - -void TaskPool::clear() -{ - // Remove all tasks using the corresponding remove function so that they are removed in a proper manner - while (!_tasks.empty()) - removeTask(_tasks.back().get()); -} - -void TaskPool::updateTasks() -{ - for (auto task : _tasks) - task->updateTask(); -} - -void TaskPool::serialize(SerializationContext& ctx) const -{ - // Serialize all tasks - ctx.beginGroup(TaskVector::Serialization_Group, true); - _tasks.serialize(TaskVector::Serialization_Element, ctx); - ctx.endGroup(); -} - -void TaskPool::deserialize(DeserializationContext& ctx) -{ - // Deserialize all tasks - if (ctx.beginGroup(TaskVector::Serialization_Group)) - { - _tasks.deserialize(TaskVector::Serialization_Element, ctx, [this](const SettingsContainer& settings) { - TaskType type = settings(Task::Serialization_Value_Type, TaskType::Undefined).toString(); - QString name = settings(Task::Serialization_Value_Name).toString(); - - return createTask(type, name); - }); - - ctx.endGroup(); - } -} +/****************************************************************************** + * File: TaskPool.cpp + * Date: 31.10.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "TaskPool.h" +#include "TaskCatalog.h" +#include "TaskExceptions.h" + +const char* TaskPool::Serialization_Group = "TaskPool"; + +TaskPool::TaskPool(const Project* project) : + _project{project}, _tasks{this} +{ + if (!project) + throw std::invalid_argument{_EXCPT("project may not be null")}; +} + +std::shared_ptr<Task> TaskPool::createTask(TaskType type, QString name, bool addToPool) +{ + if (type == TaskType::Undefined) + throw std::invalid_argument(_EXCPT("type may not be TaskType::Undefined")); + + // Create new task using the task factory; cast the unique ptr to a shared one as well + std::shared_ptr<Task> task = TaskCatalog::createTask(this, type, name); + + try { // Propagate initialization errors to the caller + task->initTask(); + } + catch (...) { + throw; + } + + if (addToPool) + addTask(task); + + return task; +} + +void TaskPool::addTask(std::shared_ptr<Task>& task) +{ + _tasks.push_back(task); + emit taskCreated(task); +} + +void TaskPool::removeTask(const Task* task) +{ + if (task) + { + auto it = _tasks.find(task); + + if (it != _tasks.cend()) + { + // Keep a copy of the shared_ptr holding the task to increase its use count; + // otherwise, the task will be deleted before it has been removed from the vector, potentially causing a crash + auto task = *it; + + emit taskRemoved(*it); + _tasks.erase(it); + } + else + throw TaskPoolException{this, _EXCPT("Tried to remove a task not belonging to the task pool")}; + } +} + +void TaskPool::clear() +{ + // Remove all tasks using the corresponding remove function so that they are removed in a proper manner + while (!_tasks.empty()) + removeTask(_tasks.back().get()); +} + +void TaskPool::updateTasks() +{ + for (auto task : _tasks) + task->updateTask(); +} + +void TaskPool::serialize(SerializationContext& ctx) const +{ + // Serialize all tasks + ctx.beginGroup(TaskVector::Serialization_Group, true); + _tasks.serialize(TaskVector::Serialization_Element, ctx); + ctx.endGroup(); +} + +void TaskPool::deserialize(DeserializationContext& ctx) +{ + // Deserialize all tasks + if (ctx.beginGroup(TaskVector::Serialization_Group)) + { + _tasks.deserialize(TaskVector::Serialization_Element, ctx, [this](const SettingsContainer& settings) { + TaskType type = settings(Task::Serialization_Value_Type, TaskType::Undefined).toString(); + QString name = settings(Task::Serialization_Value_Name).toString(); + + return createTask(type, name); + }); + + ctx.endGroup(); + } +} diff --git a/Grinder/task/TaskPool.h b/Grinder/task/TaskPool.h index 63d8f932f82fb0bc15de446044945e4766fc1ee9..834fb01664bb2747c25c6dc12e7fcbbb1dc21dde 100644 --- a/Grinder/task/TaskPool.h +++ b/Grinder/task/TaskPool.h @@ -1,51 +1,51 @@ -/****************************************************************************** - * File: TaskPool.h - * Date: 31.10.2018 - *****************************************************************************/ - -#ifndef TASKPOOL_H -#define TASKPOOL_H - -#include "TaskVector.h" - -namespace grndr -{ - class TaskPool : public QObject - { - Q_OBJECT - - public: - static const char* Serialization_Group; - - public: - TaskPool(const Project* project); - - public: - std::shared_ptr<Task> createTask(TaskType type, QString name = "", bool addToPool = true); - void addTask(std::shared_ptr<Task>& task); - void removeTask(const Task* task); - - void clear(); - - public: - void updateTasks(); - - public: - const TaskVector& tasks() const { return _tasks; } - - public: - void serialize(SerializationContext& ctx) const; - void deserialize(DeserializationContext& ctx); - - signals: - void taskCreated(const std::shared_ptr<Task>&); - void taskRemoved(const std::shared_ptr<Task>&); - - private: - const Project* _project{nullptr}; - - TaskVector _tasks; - }; -} - -#endif +/****************************************************************************** + * File: TaskPool.h + * Date: 31.10.2018 + *****************************************************************************/ + +#ifndef TASKPOOL_H +#define TASKPOOL_H + +#include "TaskVector.h" + +namespace grndr +{ + class TaskPool : public QObject + { + Q_OBJECT + + public: + static const char* Serialization_Group; + + public: + TaskPool(const Project* project); + + public: + std::shared_ptr<Task> createTask(TaskType type, QString name = "", bool addToPool = true); + void addTask(std::shared_ptr<Task>& task); + void removeTask(const Task* task); + + void clear(); + + public: + void updateTasks(); + + public: + const TaskVector& tasks() const { return _tasks; } + + public: + void serialize(SerializationContext& ctx) const; + void deserialize(DeserializationContext& ctx); + + signals: + void taskCreated(const std::shared_ptr<Task>&); + void taskRemoved(const std::shared_ptr<Task>&); + + private: + const Project* _project{nullptr}; + + TaskVector _tasks; + }; +} + +#endif diff --git a/Grinder/task/tasks/GenericTask.cpp b/Grinder/task/tasks/GenericTask.cpp index 1786f7d193be49eaca1a6c08707e9685281710a2..8f57977d89fb95a40bb0853ff47518d0b6637d55 100644 --- a/Grinder/task/tasks/GenericTask.cpp +++ b/Grinder/task/tasks/GenericTask.cpp @@ -1,96 +1,96 @@ -/****************************************************************************** - * File: GenericTask.cpp - * Date: 31.10.2018 - *****************************************************************************/ - -#include "Grinder.h" -#include "GenericTask.h" -#include "task/TaskExceptions.h" -#include "ui/task/tasks/GenericTaskWidget.h" - -const TaskType GenericTask::type_value = TaskType::Generic; - -const char* GenericTask::Serialization_Value_Command = "Command"; -const char* GenericTask::Serialization_Value_Arguments = "Arguments"; - -GenericTask::GenericTask(TaskPool* taskPool, QString name) : Task(taskPool, type_value, Task::Capability::CanBeStopped, name) -{ - -} - -ConfigureTaskWidgetBase* GenericTask::createEditor(bool newTask, QWidget* parent) -{ - return new GenericTaskWidget{this, newTask, parent}; -} - -void GenericTask::serialize(SerializationContext& ctx) const -{ - Task::serialize(ctx); - - // Serialize values - ctx.settings()(Serialization_Value_Command) = _command; - ctx.settings()(Serialization_Value_Arguments) = _arguments.join("\n"); -} - -void GenericTask::deserialize(DeserializationContext& ctx) -{ - Task::deserialize(ctx); - - // Deserialize values - _command = ctx.settings()(Serialization_Value_Command).toString(); - _arguments = ctx.settings()(Serialization_Value_Arguments).toString().split("\n"); -} - -void GenericTask::verifyTask() const -{ - Task::verifyTask(); - - if (_command.isEmpty()) - throw TaskException{this, _EXCPT("No command to execute provided")}; - - if (_process) - throw TaskException{this, _EXCPT("The task is already running")}; -} - -void GenericTask::execute() -{ - // Create a new process and start it - reportStatus(QString{"> Executing \"%1%2%3\""}.arg(_command).arg(!_arguments.empty() ? " " : "").arg(_arguments.join(" ")), false, true); - - _process = std::make_unique<QProcess>(); - _process->setProcessChannelMode(QProcess::MergedChannels); - _process->start(_command, _arguments, QIODevice::ReadOnly); - - if (_process->waitForStarted(-1)) - { - connect(_process.get(), QOverload<int, QProcess::ExitStatus>::of(&QProcess::finished), this, &GenericTask::processFinished); - connect(_process.get(), &QProcess::readyReadStandardOutput, this, &GenericTask::outputAvailable); - } - else - throw TaskException{this, _EXCPT("Unable to start the process")}; -} - -void GenericTask::stop() -{ - if (_process && _process->state() == QProcess::Running) - _process->terminate(); -} - -void GenericTask::processFinished(int exitCode, QProcess::ExitStatus exitStatus) -{ - finishTask(exitCode == 0 && exitStatus == QProcess::NormalExit); - - // Destroy the process - _process = nullptr; -} - -void GenericTask::outputAvailable() -{ - if (_process && _logOutput) - { - QString output = _process->readAllStandardOutput().toStdString().data(); - - for (auto line : output.split("\n")) - addLogMessage(line); - } -} +/****************************************************************************** + * File: GenericTask.cpp + * Date: 31.10.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "GenericTask.h" +#include "task/TaskExceptions.h" +#include "ui/task/tasks/GenericTaskWidget.h" + +const TaskType GenericTask::type_value = TaskType::Generic; + +const char* GenericTask::Serialization_Value_Command = "Command"; +const char* GenericTask::Serialization_Value_Arguments = "Arguments"; + +GenericTask::GenericTask(TaskPool* taskPool, QString name) : Task(taskPool, type_value, Task::Capability::CanBeStopped, name) +{ + +} + +ConfigureTaskWidgetBase* GenericTask::createEditor(bool newTask, QWidget* parent) +{ + return new GenericTaskWidget{this, newTask, parent}; +} + +void GenericTask::serialize(SerializationContext& ctx) const +{ + Task::serialize(ctx); + + // Serialize values + ctx.settings()(Serialization_Value_Command) = _command; + ctx.settings()(Serialization_Value_Arguments) = _arguments.join("\n"); +} + +void GenericTask::deserialize(DeserializationContext& ctx) +{ + Task::deserialize(ctx); + + // Deserialize values + _command = ctx.settings()(Serialization_Value_Command).toString(); + _arguments = ctx.settings()(Serialization_Value_Arguments).toString().split("\n"); +} + +void GenericTask::verifyTask() const +{ + Task::verifyTask(); + + if (_command.isEmpty()) + throw TaskException{this, _EXCPT("No command to execute provided")}; + + if (_process) + throw TaskException{this, _EXCPT("The task is already running")}; +} + +void GenericTask::execute() +{ + // Create a new process and start it + reportStatus(QString{"> Executing \"%1%2%3\""}.arg(_command).arg(!_arguments.empty() ? " " : "").arg(_arguments.join(" ")), false, true); + + _process = std::make_unique<QProcess>(); + _process->setProcessChannelMode(QProcess::MergedChannels); + _process->start(_command, _arguments, QIODevice::ReadOnly); + + if (_process->waitForStarted(-1)) + { + connect(_process.get(), QOverload<int, QProcess::ExitStatus>::of(&QProcess::finished), this, &GenericTask::processFinished); + connect(_process.get(), &QProcess::readyReadStandardOutput, this, &GenericTask::outputAvailable); + } + else + throw TaskException{this, _EXCPT("Unable to start the process")}; +} + +void GenericTask::stop() +{ + if (_process && _process->state() == QProcess::Running) + _process->terminate(); +} + +void GenericTask::processFinished(int exitCode, QProcess::ExitStatus exitStatus) +{ + finishTask(exitCode == 0 && exitStatus == QProcess::NormalExit); + + // Destroy the process + _process = nullptr; +} + +void GenericTask::outputAvailable() +{ + if (_process && _logOutput) + { + QString output = _process->readAllStandardOutput().toStdString().data(); + + for (auto line : output.split("\n")) + addLogMessage(line); + } +} diff --git a/Grinder/task/tasks/GenericTask.h b/Grinder/task/tasks/GenericTask.h index 18736ad5d7e2cee94d42763b00a91ed58cdf788d..f12832b3042d8f83c010ee44b5a9d966940d98db 100644 --- a/Grinder/task/tasks/GenericTask.h +++ b/Grinder/task/tasks/GenericTask.h @@ -1,64 +1,64 @@ -/****************************************************************************** - * File: GenericTask.h - * Date: 31.10.2018 - *****************************************************************************/ - -#ifndef GENERICTASK_H -#define GENERICTASK_H - -#include "task/Task.h" - -namespace grndr -{ - class GenericTask : public Task - { - Q_OBJECT - - public: - static const TaskType type_value; - - static const char* Serialization_Value_Command; - static const char* Serialization_Value_Arguments; - - public: - GenericTask(TaskPool* taskPool, QString name = ""); - - public: - virtual ConfigureTaskWidgetBase* createEditor(bool newTask, QWidget* parent) override; - - public: - QString getCommand() const { return _command; } - void setCommand(QString cmd) { _command = cmd; } - QStringList& arguments() { return _arguments; } - const QStringList& arguments() const { return _arguments; } - - bool getLogOutput() const { return _logOutput; } - void setLogOutput(bool set = true) { _logOutput = set; } - - public: - virtual void serialize(SerializationContext& ctx) const override; - virtual void deserialize(DeserializationContext& ctx) override; - - protected: - virtual void verifyTask() const override; - - virtual void execute() override; - virtual void stop() override; - - protected: - QString _command; - QStringList _arguments; - - bool _logOutput{true}; - - private slots: - void processFinished(int exitCode, QProcess::ExitStatus exitStatus); - - void outputAvailable(); - - private: - std::unique_ptr<QProcess> _process; - }; -} - -#endif +/****************************************************************************** + * File: GenericTask.h + * Date: 31.10.2018 + *****************************************************************************/ + +#ifndef GENERICTASK_H +#define GENERICTASK_H + +#include "task/Task.h" + +namespace grndr +{ + class GenericTask : public Task + { + Q_OBJECT + + public: + static const TaskType type_value; + + static const char* Serialization_Value_Command; + static const char* Serialization_Value_Arguments; + + public: + GenericTask(TaskPool* taskPool, QString name = ""); + + public: + virtual ConfigureTaskWidgetBase* createEditor(bool newTask, QWidget* parent) override; + + public: + QString getCommand() const { return _command; } + void setCommand(QString cmd) { _command = cmd; } + QStringList& arguments() { return _arguments; } + const QStringList& arguments() const { return _arguments; } + + bool getLogOutput() const { return _logOutput; } + void setLogOutput(bool set = true) { _logOutput = set; } + + public: + virtual void serialize(SerializationContext& ctx) const override; + virtual void deserialize(DeserializationContext& ctx) override; + + protected: + virtual void verifyTask() const override; + + virtual void execute() override; + virtual void stop() override; + + protected: + QString _command; + QStringList _arguments; + + bool _logOutput{true}; + + private slots: + void processFinished(int exitCode, QProcess::ExitStatus exitStatus); + + void outputAvailable(); + + private: + std::unique_ptr<QProcess> _process; + }; +} + +#endif diff --git a/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.cpp b/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.cpp index ad6d10e2ddd02848085f485ca857e31ce6eed1fc..f7f9a4e158c2184770e1540ef23f1adcd0de88d4 100644 --- a/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.cpp +++ b/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.cpp @@ -33,15 +33,6 @@ void BaristaInferenceTaskWidget::verifySettings() if (ui->lstNetworkState->currentText().isEmpty()) showError("Please select a network state.", ui->lstNetworkState); - if (!getLabel()) - showError("Please select a label.", ui->lstLabel); - - if (!getCanvasBlock()) - showError("Please select a canvas block.", ui->lstCanvasBlock); - - if (getImageReferences().empty()) - showError("Please select at least one image.", ui->lstImageReferences); - if (!ui->chkGenerateItems->isChecked() && !ui->chkGenerateMaps->isChecked()) showError("Please select at least one result type.", ui->chkGenerateItems); } @@ -58,10 +49,6 @@ void BaristaInferenceTaskWidget::applySettings(bool save) _task->setRemoteDirectory(ui->txtRemoteDir->text()); _task->setNetworkStateFile(ui->lstNetworkState->currentText()); - _task->setLabel(getLabel()); - _task->setCanvasBlock(getCanvasBlock()); - _task->setImageReferences(getImageReferences()); - _task->setProbabilityResultTypes(getProbabilityResultTypes()); _task->setProbabilityThreshold(ui->txtThreshold->value()); } @@ -74,10 +61,6 @@ void BaristaInferenceTaskWidget::applySettings(bool save) ui->txtOutputDir->setText(_task->getOutputDirectory()); ui->txtRemoteDir->setText(_task->getRemoteDirectory()); - ui->lstLabel->selectLabel(_task->getLabel()); - ui->lstCanvasBlock->selectCanvasBlock(_task->getCanvasBlock()); - ui->lstImageReferences->selectImageReferences(_task->getImageReferences()); - setProbabilityResultTypes(_task->getProbabilityResultTypes()); ui->txtThreshold->setValue(_task->getProbabilityThreshold()); @@ -85,12 +68,6 @@ void BaristaInferenceTaskWidget::applySettings(bool save) } } -void BaristaInferenceTaskWidget::on_lstLabel_currentIndexChanged(int index) -{ - Q_UNUSED(index); - fillCanvasBlocks(ui->lstLabel->getSelectedLabel()); -} - void BaristaInferenceTaskWidget::on_txtOutputDir_editingFinished() { fillNetworkStateFiles(); @@ -101,9 +78,6 @@ void BaristaInferenceTaskWidget::setupUi() ui->setupUi(this); ui->lstNetworks->populate(); - - fillLabels(grinder()->project().labels()); - fillImageReferences(grinder()->project().imageReferences()); } void BaristaInferenceTaskWidget::fillNetworkStateFiles() @@ -137,21 +111,6 @@ void BaristaInferenceTaskWidget::fillNetworkStateFiles() ui->lstNetworkState->setCurrentIndex(ui->lstNetworkState->count() - 1); } -void BaristaInferenceTaskWidget::fillLabels(const LabelVector& labels) -{ - ui->lstLabel->populate(labels); -} - -void BaristaInferenceTaskWidget::fillCanvasBlocks(const Label* label) -{ - ui->lstCanvasBlock->populate(label); -} - -void BaristaInferenceTaskWidget::fillImageReferences(const ImageReferenceVector& imageRefs) -{ - ui->lstImageReferences->populate(imageRefs); -} - BaristaInferenceTask::ProbabilityResultTypes BaristaInferenceTaskWidget::getProbabilityResultTypes() const { BaristaInferenceTask::ProbabilityResultTypes types = BaristaInferenceTask::ProbabilityResultType::None; @@ -170,18 +129,3 @@ void BaristaInferenceTaskWidget::setProbabilityResultTypes(BaristaInferenceTask: ui->chkGenerateItems->setChecked(types.testFlag(BaristaInferenceTask::ProbabilityResultType::Items)); ui->chkGenerateMaps->setChecked(types.testFlag(BaristaInferenceTask::ProbabilityResultType::Maps)); } - -Label* BaristaInferenceTaskWidget::getLabel() const -{ - return ui->lstLabel->getSelectedLabel(); -} - -Block* BaristaInferenceTaskWidget::getCanvasBlock() const -{ - return ui->lstCanvasBlock->getSelectedCanvasBlock(); -} - -std::vector<const ImageReference*> BaristaInferenceTaskWidget::getImageReferences() const -{ - return ui->lstImageReferences->getSelectedImageReferences(); -} diff --git a/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.h b/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.h index a42ada96cc298b6010415108e745ca22a33116f0..3ec66689f8daf502870f9a7cea1cc25571a6546e 100644 --- a/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.h +++ b/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.h @@ -35,7 +35,6 @@ namespace grndr virtual void applySettings(bool save) override; private slots: - void on_lstLabel_currentIndexChanged(int index); void on_txtOutputDir_editingFinished(); private: @@ -44,16 +43,9 @@ namespace grndr private: void fillNetworkStateFiles(); - void fillLabels(const LabelVector& labels); - void fillCanvasBlocks(const Label* label); - void fillImageReferences(const ImageReferenceVector& imageRefs); BaristaInferenceTask::ProbabilityResultTypes getProbabilityResultTypes() const; void setProbabilityResultTypes(BaristaInferenceTask::ProbabilityResultTypes types); - - Label* getLabel() const; - Block* getCanvasBlock() const; - std::vector<const ImageReference*> getImageReferences() const; }; } diff --git a/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.ui b/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.ui index 772ba35055cce86fb7a08d09f8f61db40e51b660..9c059b023d5ab9031c7f7259b3970caa9c403689 100644 --- a/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.ui +++ b/Grinder/ui/barista/tasks/BaristaInferenceTaskWidget.ui @@ -7,7 +7,7 @@ <x>0</x> <y>0</y> <width>404</width> - <height>666</height> + <height>367</height> </rect> </property> <property name="windowTitle"> @@ -141,90 +141,6 @@ </widget> </item> <item row="2" column="0"> - <widget class="QGroupBox" name="groupBox_2"> - <property name="title"> - <string>Image settings</string> - </property> - <layout class="QGridLayout" name="gridLayout"> - <item row="0" column="0"> - <widget class="QLabel" name="label_6"> - <property name="text"> - <string>Label:</string> - </property> - <property name="buddy"> - <cstring>lstLabel</cstring> - </property> - </widget> - </item> - <item row="0" column="1"> - <widget class="LabelsComboBox" name="lstLabel"> - <property name="sizePolicy"> - <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="minimumSize"> - <size> - <width>150</width> - <height>0</height> - </size> - </property> - </widget> - </item> - <item row="1" column="0"> - <widget class="QLabel" name="label_5"> - <property name="text"> - <string>Canvas block:</string> - </property> - <property name="buddy"> - <cstring>lstCanvasBlock</cstring> - </property> - </widget> - </item> - <item row="1" column="1"> - <widget class="CanvasBlocksComboBox" name="lstCanvasBlock"> - <property name="sizePolicy"> - <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="minimumSize"> - <size> - <width>150</width> - <height>0</height> - </size> - </property> - </widget> - </item> - <item row="2" column="1"> - <spacer name="verticalSpacer_5"> - <property name="orientation"> - <enum>Qt::Vertical</enum> - </property> - <property name="sizeType"> - <enum>QSizePolicy::Fixed</enum> - </property> - <property name="sizeHint" stdset="0"> - <size> - <width>20</width> - <height>6</height> - </size> - </property> - </spacer> - </item> - <item row="3" column="0" colspan="2"> - <widget class="ImageReferencesCheckListWidget" name="lstImageReferences"> - <property name="selectionMode"> - <enum>QAbstractItemView::ExtendedSelection</enum> - </property> - </widget> - </item> - </layout> - </widget> - </item> - <item row="3" column="0"> <widget class="QGroupBox" name="groupBox_4"> <property name="title"> <string>Result settings</string> @@ -310,21 +226,6 @@ </layout> </widget> <customwidgets> - <customwidget> - <class>LabelsComboBox</class> - <extends>QComboBox</extends> - <header>ui/widgets/project/LabelsComboBox.h</header> - </customwidget> - <customwidget> - <class>CanvasBlocksComboBox</class> - <extends>QComboBox</extends> - <header>ui/widgets/engine/CanvasBlocksComboBox.h</header> - </customwidget> - <customwidget> - <class>ImageReferencesCheckListWidget</class> - <extends>QListWidget</extends> - <header>ui/widgets/project/ImageReferencesCheckListWidget.h</header> - </customwidget> <customwidget> <class>BaristaNetworksComboBox</class> <extends>QComboBox</extends> @@ -338,9 +239,6 @@ <tabstop>txtOutputDir</tabstop> <tabstop>txtRemoteDir</tabstop> <tabstop>lstNetworkState</tabstop> - <tabstop>lstLabel</tabstop> - <tabstop>lstCanvasBlock</tabstop> - <tabstop>lstImageReferences</tabstop> <tabstop>chkGenerateItems</tabstop> <tabstop>sldThreshold</tabstop> <tabstop>txtThreshold</tabstop> diff --git a/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.cpp b/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.cpp index 88b6a614eb3fe78dd8793c553efac4c76c42423e..e036d08e04054a50dd9a2513b885ad8dafa0060b 100644 --- a/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.cpp +++ b/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.cpp @@ -31,25 +31,6 @@ void BaristaTrainingTaskWidget::verifySettings() if (ui->txtOutputDir->text().isEmpty()) showError("Please enter an output directory.", ui->txtOutputDir); - - if (!getLabel()) - showError("Please select a label.", ui->lstLabel); - - if (auto canvasBlock = getCanvasBlock()) - { - bool hasTags = false; - - if (auto imageTagsProperty = canvasBlock->portProperty<ImageTagsProperty>(PortType::ImageTagsIn, PropertyID::ImageTags)) - hasTags = !imageTagsProperty->object().tags().empty(); - - if (!hasTags) - showError("The selected canvas block doesn't have any tags assigned.", ui->lstCanvasBlock); - } - else - showError("Please select a canvas block.", ui->lstCanvasBlock); - - if (getImageReferences().empty()) - showError("Please select at least one image.", ui->lstImageReferences); } void BaristaTrainingTaskWidget::applySettings(bool save) @@ -66,10 +47,6 @@ void BaristaTrainingTaskWidget::applySettings(bool save) _task->setMaxIterations(ui->txtMaxIterations->value()); _task->setDisplayInterval(ui->txtDisplayInterval->value()); _task->setSnapshotInterval(ui->txtSnapshotInterval->value()); - - _task->setLabel(getLabel()); - _task->setCanvasBlock(getCanvasBlock()); - _task->setImageReferences(getImageReferences()); } else { @@ -83,55 +60,12 @@ void BaristaTrainingTaskWidget::applySettings(bool save) ui->txtMaxIterations->setValue(_task->getMaxIterations()); ui->txtDisplayInterval->setValue(_task->getDisplayInterval()); ui->txtSnapshotInterval->setValue(_task->getSnapshotInterval()); - - ui->lstLabel->selectLabel(_task->getLabel()); - ui->lstCanvasBlock->selectCanvasBlock(_task->getCanvasBlock()); - ui->lstImageReferences->selectImageReferences(_task->getImageReferences()); } } -void BaristaTrainingTaskWidget::on_lstLabel_currentIndexChanged(int index) -{ - Q_UNUSED(index); - fillCanvasBlocks(ui->lstLabel->getSelectedLabel()); -} - void BaristaTrainingTaskWidget::setupUi() { ui->setupUi(this); ui->lstNetworks->populate(); - - fillLabels(grinder()->project().labels()); - fillImageReferences(grinder()->project().imageReferences()); -} - -void BaristaTrainingTaskWidget::fillLabels(const LabelVector& labels) -{ - ui->lstLabel->populate(labels); -} - -void BaristaTrainingTaskWidget::fillCanvasBlocks(const Label* label) -{ - ui->lstCanvasBlock->populate(label); -} - -void BaristaTrainingTaskWidget::fillImageReferences(const ImageReferenceVector& imageRefs) -{ - ui->lstImageReferences->populate(imageRefs); -} - -Label* BaristaTrainingTaskWidget::getLabel() const -{ - return ui->lstLabel->getSelectedLabel(); -} - -Block* BaristaTrainingTaskWidget::getCanvasBlock() const -{ - return ui->lstCanvasBlock->getSelectedCanvasBlock(); -} - -std::vector<const ImageReference*> BaristaTrainingTaskWidget::getImageReferences() const -{ - return ui->lstImageReferences->getSelectedImageReferences(); } diff --git a/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.h b/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.h index bd4b32b27d7c459f781f62e43bf8561afa7b9c63..818b80e9e15846844859e5033375146ebe3d3015 100644 --- a/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.h +++ b/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.h @@ -34,21 +34,9 @@ namespace grndr virtual void verifySettings() override; virtual void applySettings(bool save) override; - private slots: - void on_lstLabel_currentIndexChanged(int index); - private: Ui::BaristaTrainingTaskWidget *ui; void setupUi(); - - private: - void fillLabels(const LabelVector& labels); - void fillCanvasBlocks(const Label* label); - void fillImageReferences(const ImageReferenceVector& imageRefs); - - Label* getLabel() const; - Block* getCanvasBlock() const; - std::vector<const ImageReference*> getImageReferences() const; }; } diff --git a/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.ui b/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.ui index 9cb40ff6274c21bf982a32f79c43fa32f1c200ae..9845ce88a8aa34f4caa806a26cd3a96b38bb5826 100644 --- a/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.ui +++ b/Grinder/ui/barista/tasks/BaristaTrainingTaskWidget.ui @@ -7,7 +7,7 @@ <x>0</x> <y>0</y> <width>386</width> - <height>610</height> + <height>320</height> </rect> </property> <property name="windowTitle"> @@ -26,90 +26,6 @@ <property name="bottomMargin"> <number>0</number> </property> - <item row="3" column="0"> - <widget class="QGroupBox" name="groupBox_3"> - <property name="title"> - <string>Image settings</string> - </property> - <layout class="QGridLayout" name="gridLayout_3"> - <item row="1" column="0"> - <widget class="QLabel" name="label_5"> - <property name="text"> - <string>Canvas block:</string> - </property> - <property name="buddy"> - <cstring>lstCanvasBlock</cstring> - </property> - </widget> - </item> - <item row="1" column="1"> - <widget class="CanvasBlocksComboBox" name="lstCanvasBlock"> - <property name="sizePolicy"> - <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="minimumSize"> - <size> - <width>150</width> - <height>0</height> - </size> - </property> - </widget> - </item> - <item row="0" column="1"> - <widget class="LabelsComboBox" name="lstLabel"> - <property name="sizePolicy"> - <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> - <horstretch>0</horstretch> - <verstretch>0</verstretch> - </sizepolicy> - </property> - <property name="minimumSize"> - <size> - <width>150</width> - <height>0</height> - </size> - </property> - </widget> - </item> - <item row="4" column="0" colspan="2"> - <widget class="ImageReferencesCheckListWidget" name="lstImageReferences"> - <property name="selectionMode"> - <enum>QAbstractItemView::ExtendedSelection</enum> - </property> - </widget> - </item> - <item row="0" column="0"> - <widget class="QLabel" name="label_6"> - <property name="text"> - <string>Label:</string> - </property> - <property name="buddy"> - <cstring>lstLabel</cstring> - </property> - </widget> - </item> - <item row="2" column="1"> - <spacer name="verticalSpacer_5"> - <property name="orientation"> - <enum>Qt::Vertical</enum> - </property> - <property name="sizeType"> - <enum>QSizePolicy::Fixed</enum> - </property> - <property name="sizeHint" stdset="0"> - <size> - <width>20</width> - <height>6</height> - </size> - </property> - </spacer> - </item> - </layout> - </widget> - </item> <item row="0" column="0"> <widget class="QGroupBox" name="groupBox"> <property name="title"> @@ -326,21 +242,6 @@ </layout> </widget> <customwidgets> - <customwidget> - <class>LabelsComboBox</class> - <extends>QComboBox</extends> - <header>ui/widgets/project/LabelsComboBox.h</header> - </customwidget> - <customwidget> - <class>CanvasBlocksComboBox</class> - <extends>QComboBox</extends> - <header>ui/widgets/engine/CanvasBlocksComboBox.h</header> - </customwidget> - <customwidget> - <class>ImageReferencesCheckListWidget</class> - <extends>QListWidget</extends> - <header>ui/widgets/project/ImageReferencesCheckListWidget.h</header> - </customwidget> <customwidget> <class>BaristaNetworksComboBox</class> <extends>QComboBox</extends> @@ -356,9 +257,6 @@ <tabstop>txtMaxIterations</tabstop> <tabstop>txtDisplayInterval</tabstop> <tabstop>txtSnapshotInterval</tabstop> - <tabstop>lstLabel</tabstop> - <tabstop>lstCanvasBlock</tabstop> - <tabstop>lstImageReferences</tabstop> </tabstops> <resources/> <connections/> diff --git a/Grinder/util/FileUtils.cpp b/Grinder/util/FileUtils.cpp index 5c35f91cfcf5297e69e0782b00fb9f65166ec3d6..4a3e3702d2290bc8d48ccfa7a62f086d2e3a5e39 100644 --- a/Grinder/util/FileUtils.cpp +++ b/Grinder/util/FileUtils.cpp @@ -1,65 +1,73 @@ -/****************************************************************************** - * File: FileUtils.cpp - * Date: 10.2.2018 - *****************************************************************************/ - -#include "Grinder.h" -#include "FileUtils.h" - -bool FileUtils::compareFileNames(QString fileName1, QString fileName2) -{ - fileName1.replace('\\', '/'); - fileName2.replace('\\', '/'); - - return fileName1.toLower() == fileName2.toLower(); -} - -QStringList FileUtils::expandFileList(QStringList paths, QStringList filters) -{ - QStringList files; - - for (auto path : paths) - { - QFileInfo fileInfo{path}; - - if (fileInfo.isDir()) - { - QDirIterator dirIt{fileInfo.filePath(), filters, QDir::NoFilter, QDirIterator::Subdirectories}; - - while (dirIt.hasNext()) - { - auto file = dirIt.next(); - - if (dirIt.fileInfo().isFile()) - files << file; - } - } - else - files << path; - } - - return files; -} - -QStringList FileUtils::getDirectoryList(QString path) -{ - QStringList dirs; - QFileInfo fileInfo{path}; - - if (fileInfo.isDir()) - { - dirs << path; - - QDirIterator dirIt{fileInfo.filePath(), QDir::Dirs|QDir::NoDotAndDotDot, QDirIterator::Subdirectories}; - - while (dirIt.hasNext()) - { - auto dir = dirIt.next(); - - if (dirIt.fileInfo().isDir()) - dirs << dir; - } - } - - return dirs; -} +/****************************************************************************** + * File: FileUtils.cpp + * Date: 10.2.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "FileUtils.h" + +bool FileUtils::compareFileNames(QString fileName1, QString fileName2) +{ + fileName1.replace('\\', '/'); + fileName2.replace('\\', '/'); + + return fileName1.toLower() == fileName2.toLower(); +} + +QStringList FileUtils::expandFileList(QStringList paths, QStringList filters) +{ + QStringList files; + + for (auto path : paths) + { + QFileInfo fileInfo{path}; + + if (fileInfo.isDir()) + { + QDirIterator dirIt{fileInfo.filePath(), filters, QDir::NoFilter, QDirIterator::Subdirectories}; + + while (dirIt.hasNext()) + { + auto file = dirIt.next(); + + if (dirIt.fileInfo().isFile()) + files << file; + } + } + else + files << path; + } + + return files; +} + +QStringList FileUtils::getDirectoryList(QString path) +{ + QStringList dirs; + QFileInfo fileInfo{path}; + + if (fileInfo.isDir()) + { + dirs << path; + + QDirIterator dirIt{fileInfo.filePath(), QDir::Dirs|QDir::NoDotAndDotDot, QDirIterator::Subdirectories}; + + while (dirIt.hasNext()) + { + auto dir = dirIt.next(); + + if (dirIt.fileInfo().isDir()) + dirs << dir; + } + } + + return dirs; +} + +QString FileUtils::getTemporaryFileName(QString filename) +{ + QFileInfo fi{QDir{QStandardPaths::writableLocation(QStandardPaths::TempLocation)}, "Grinder"}; + QDir dir = fi.filePath(); + dir.mkpath(dir.path()); + return dir.filePath(filename); +} diff --git a/Grinder/util/FileUtils.h b/Grinder/util/FileUtils.h index fc2e2b857f2f368d6b59655ee03c063f4c9cf662..9898e5269fda75839c859357c5a05b0b4961a216 100644 --- a/Grinder/util/FileUtils.h +++ b/Grinder/util/FileUtils.h @@ -1,26 +1,28 @@ -/****************************************************************************** - * File: FileUtils.h - * Date: 10.2.2018 - *****************************************************************************/ - -#ifndef FILEUTILS_H -#define FILEUTILS_H - -#include <QStringList> - -namespace grndr -{ - class FileUtils final - { - public: - static bool compareFileNames(QString fileName1, QString fileName2); - - static QStringList expandFileList(QStringList paths, QStringList filters = QStringList{}); - static QStringList getDirectoryList(QString path); - - private: - FileUtils(); - }; -} - -#endif +/****************************************************************************** + * File: FileUtils.h + * Date: 10.2.2018 + *****************************************************************************/ + +#ifndef FILEUTILS_H +#define FILEUTILS_H + +#include <QStringList> + +namespace grndr +{ + class FileUtils final + { + public: + static bool compareFileNames(QString fileName1, QString fileName2); + + static QStringList expandFileList(QStringList paths, QStringList filters = QStringList{}); + static QStringList getDirectoryList(QString path); + + static QString getTemporaryFileName(QString filename); + + private: + FileUtils(); + }; +} + +#endif