diff --git a/Grinder/Grinder.pro b/Grinder/Grinder.pro index 39328b19cfee2430c3bde9ba8209f7d4e27f97fd..4728f4e8a28c5ca28e5becd62992d1cf37a9e189 100644 --- a/Grinder/Grinder.pro +++ b/Grinder/Grinder.pro @@ -347,9 +347,18 @@ SOURCES += \ ui/task/ConfigureTaskWidgetBase.cpp \ task/tasks/BaristaTask.cpp \ ui/task/tasks/GenericTaskWidget.cpp \ - task/tasks/BaristaMessage.cpp \ + barista/BaristaMessage.cpp \ task/tasks/BaristaTrainingTask.cpp \ - ui/task/tasks/BaristaTrainingTaskWidget.cpp + ui/task/tasks/BaristaTrainingTaskWidget.cpp \ + task/tasks/BaristaInferenceTask.cpp \ + ui/task/tasks/BaristaInferenceTaskWidget.cpp \ + barista/BaristaBinaryData.cpp \ + ui/widgets/LabelsComboBox.cpp \ + ui/widgets/OutputBlocksComboBox.cpp \ + ui/widgets/ImageReferencesCheckListWidget.cpp \ + barista/BaristaBinaryRecvData.cpp \ + barista/BaristaBinarySendData.cpp \ + barista/BaristaSocket.cpp HEADERS += \ ui/mainwnd/GrinderWindow.h \ @@ -742,10 +751,19 @@ HEADERS += \ ui/task/ConfigureTaskWidgetBase.h \ task/tasks/BaristaTask.h \ ui/task/tasks/GenericTaskWidget.h \ - task/tasks/BaristaMessage.h \ + barista/BaristaMessage.h \ task/tasks/BaristaTrainingTask.h \ ui/task/tasks/BaristaTrainingTaskWidget.h \ - task/tasks/BaristaProtocol.h + barista/BaristaProtocol.h \ + task/tasks/BaristaInferenceTask.h \ + ui/task/tasks/BaristaInferenceTaskWidget.h \ + barista/BaristaBinaryData.h \ + ui/widgets/LabelsComboBox.h \ + ui/widgets/OutputBlocksComboBox.h \ + ui/widgets/ImageReferencesCheckListWidget.h \ + barista/BaristaBinaryRecvData.h \ + barista/BaristaBinarySendData.h \ + barista/BaristaSocket.h FORMS += \ ui/mainwnd/GrinderWindow.ui \ @@ -762,7 +780,8 @@ FORMS += \ ui/dlg/TextViewerDialog.ui \ ui/task/ConfigureTaskDialog.ui \ ui/task/tasks/GenericTaskWidget.ui \ - ui/task/tasks/BaristaTrainingTaskWidget.ui + ui/task/tasks/BaristaTrainingTaskWidget.ui \ + ui/task/tasks/BaristaInferenceTaskWidget.ui RESOURCES += \ res/Grinder.qrc diff --git a/Grinder/Version.h b/Grinder/Version.h index 3989e7fabcec16d6b7e6611c07da2eaeb6505a14..7bed69dafe74383c9d69410bdbb4df72db04ccb1 100644 --- a/Grinder/Version.h +++ b/Grinder/Version.h @@ -10,14 +10,14 @@ #define GRNDR_INFO_TITLE "Grinder" #define GRNDR_INFO_COPYRIGHT "Copyright (c) WWU Muenster" -#define GRNDR_INFO_DATE "30.11.2018" +#define GRNDR_INFO_DATE "18.12.2018" #define GRNDR_INFO_COMPANY "WWU Muenster" #define GRNDR_INFO_WEBSITE "http://www.uni-muenster.de" #define GRNDR_VERSION_MAJOR 0 #define GRNDR_VERSION_MINOR 10 #define GRNDR_VERSION_REVISION 0 -#define GRNDR_VERSION_BUILD 291 +#define GRNDR_VERSION_BUILD 296 namespace grndr { diff --git a/Grinder/barista/BaristaBinaryData.cpp b/Grinder/barista/BaristaBinaryData.cpp new file mode 100644 index 0000000000000000000000000000000000000000..637860be1685e830c2e4ad2e875089045f81babe --- /dev/null +++ b/Grinder/barista/BaristaBinaryData.cpp @@ -0,0 +1,59 @@ +/****************************************************************************** + * File: BaristaBinaryMessage.cpp + * Date: 07.12.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "BaristaBinaryData.h" + +int BaristaBinaryData::getCVDataType(QString type) +{ + // Only a subset of all OpenCV types is supported + if (type.compare("float64", Qt::CaseInsensitive) == 0) + return CV_64FC1; + if (type.compare("float32", Qt::CaseInsensitive) == 0) + return CV_32FC1; + else if (type.compare("int32", Qt::CaseInsensitive) == 0 || type.compare("uint32", Qt::CaseInsensitive) == 0) + return CV_32SC1; + else if (type.compare("int16", Qt::CaseInsensitive) == 0) + return CV_16SC1; + else if (type.compare("int8", Qt::CaseInsensitive) == 0) + return CV_8SC1; + else if (type.compare("uint16", Qt::CaseInsensitive) == 0) + return CV_16UC1; + else if (type.compare("uint8", Qt::CaseInsensitive) == 0) + return CV_8UC1; + else + return CV_32FC1; // Always fall back to float32 +} + +QString BaristaBinaryData::getBaristaDataType(int type) +{ + // Only a subset of all OpenCV types is supported + switch (type) + { + case CV_64FC1: + return "float64"; + + case CV_32FC1: + return "float32"; + + case CV_32SC1: + return "int32"; + + case CV_16SC1: + return "int16"; + + case CV_8SC1: + return "int8"; + + case CV_16UC1: + return "uint16"; + + case CV_8UC1: + return "uint8"; + + default: + return "float32"; // Always fall back to float32 + } +} diff --git a/Grinder/barista/BaristaBinaryData.h b/Grinder/barista/BaristaBinaryData.h new file mode 100644 index 0000000000000000000000000000000000000000..08bc1dfbe3895e1fc70cbbada4f265fa88734793 --- /dev/null +++ b/Grinder/barista/BaristaBinaryData.h @@ -0,0 +1,24 @@ +/****************************************************************************** + * File: BaristaBinaryMessage.h + * Date: 07.12.2018 + *****************************************************************************/ + +#ifndef BARISTABINARYDATA_H +#define BARISTABINARYDATA_H + +#include <zmq.hpp> +#include <opencv2/core.hpp> + +#include "common/serialization/SettingsContainer.h" + +namespace grndr +{ + class BaristaBinaryData + { + public: + static int getCVDataType(QString type); + static QString getBaristaDataType(int type); + }; +} + +#endif diff --git a/Grinder/barista/BaristaBinaryRecvData.cpp b/Grinder/barista/BaristaBinaryRecvData.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0949e25fd188a5dce6873544987e9cc436f40c45 --- /dev/null +++ b/Grinder/barista/BaristaBinaryRecvData.cpp @@ -0,0 +1,119 @@ +/****************************************************************************** + * File: BaristaBinaryRecvData.cpp + * Date: 17.12.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "BaristaBinaryRecvData.h" +#include "barista/BaristaSocket.h" + +BaristaBinaryRecvData::BaristaBinaryRecvData(const SettingsContainer& messageData, BaristaSocket& socket) +{ + receive(messageData, socket); +} + +void BaristaBinaryRecvData::receive(const SettingsContainer& messageData, BaristaSocket& socket) +{ + // Get the data layout first + _layout.dataType = messageData("dtype").toString(); + + if (auto shape = messageData.child("shape")) + { + for (auto dim : shape->values()) + _layout.dimensions.push_back(dim.toInt()); + } + + if (_layout.dataType.isEmpty()) + throw std::runtime_error{"Invalid data type"}; + + if (_layout.dimensions.empty()) + throw std::runtime_error{"Invalid dimensions"}; + + // Find all "raw" values + std::map<int, QString> rawValues; + auto values = messageData.values(); + + for (auto it = values.cbegin(); it != values.cend(); ++it) + { + if (it.key().startsWith("raw_", Qt::CaseInsensitive)) + rawValues[it.value().toInt()] = it.key(); + } + + // Get all raw data in order + for (auto rawValue : rawValues) + { + zmq::message_t message; + socket.recvMessage(message); + + _binaryMessages[rawValue.second] = std::move(message); + } +} + +bool BaristaBinaryRecvData::checkDataType(QString type) const +{ + return _layout.dataType.compare(type, Qt::CaseInsensitive) == 0; +} + +bool BaristaBinaryRecvData::checkDataType(int type) const +{ + return checkDataType(getBaristaDataType(type)); +} + +bool BaristaBinaryRecvData::checkDimensions(std::vector<int> dims) const +{ + if (dims.size() == _layout.dimensions.size()) + return std::equal(dims.cbegin(), dims.cend(), _layout.dimensions.cbegin()); + else + return false; +} + +QStringList BaristaBinaryRecvData::getDataNames() const +{ + QStringList names; + + for (const auto& binaryMessage : _binaryMessages) + names << binaryMessage.first; + + return names; +} + +zmq::message_t* BaristaBinaryRecvData::data(QString key) +{ + auto it = _binaryMessages.find(key); + + if (it != _binaryMessages.cend()) + return &it->second; + else + return nullptr; +} + +const zmq::message_t* BaristaBinaryRecvData::data(QString key) const +{ + auto it = _binaryMessages.find(key); + + if (it != _binaryMessages.cend()) + return &it->second; + else + return nullptr; +} + +cv::Mat BaristaBinaryRecvData::matrix(QString key) const +{ + cv::Mat mat{_layout.dimensions, getCVDataType(_layout.dataType), cv::Scalar::all(0)}; + + auto it = _binaryMessages.find(key); + + if (it != _binaryMessages.cend()) + { + // Directly copy the message data to the matrix, ensuring that the matrix is continous and big enough + if (mat.isContinuous() && (mat.total() * mat.elemSize() >= it->second.size())) + { + std::memcpy(mat.ptr(), it->second.data(), it->second.size()); + return mat; + } + else + throw std::runtime_error{"Invalid generated matrix"}; + } + else + return {}; +} diff --git a/Grinder/barista/BaristaBinaryRecvData.h b/Grinder/barista/BaristaBinaryRecvData.h new file mode 100644 index 0000000000000000000000000000000000000000..b4e80eefe1598e7f0d27d7560fb9fe3e1bfd92c0 --- /dev/null +++ b/Grinder/barista/BaristaBinaryRecvData.h @@ -0,0 +1,49 @@ +/****************************************************************************** + * File: BaristaBinaryRecvData.h + * Date: 17.12.2018 + *****************************************************************************/ + +#ifndef BARISTABINARYRECVDATA_H +#define BARISTABINARYRECVDATA_H + +#include "BaristaBinaryData.h" + +namespace grndr +{ + class BaristaSocket; + + class BaristaBinaryRecvData : public BaristaBinaryData + { + public: + BaristaBinaryRecvData(const SettingsContainer& messageData, BaristaSocket& socket); + + public: + void receive(const SettingsContainer& messageData, BaristaSocket& socket); + + public: + bool checkDataType(QString type) const; + bool checkDataType(int type) const; + bool checkDimensions(std::vector<int> dims) const; + + public: + QString getDataType() const { return _layout.dataType; } + std::vector<int> getDimensions() const { return _layout.dimensions; } + + QStringList getDataNames() const; + + zmq::message_t* data(QString key); + const zmq::message_t* data(QString key) const; + cv::Mat matrix(QString key) const; + + private: + struct + { + QString dataType{""}; + std::vector<int> dimensions; + } _layout; + + std::map<QString, zmq::message_t> _binaryMessages; + }; +} + +#endif diff --git a/Grinder/barista/BaristaBinarySendData.cpp b/Grinder/barista/BaristaBinarySendData.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5850cafa17570b19e1700275733286ef069eab81 --- /dev/null +++ b/Grinder/barista/BaristaBinarySendData.cpp @@ -0,0 +1,82 @@ +/****************************************************************************** + * File: BaristaBinarySendData.cpp + * Date: 17.12.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "BaristaBinarySendData.h" +#include "barista/BaristaSocket.h" +#include "core/GrinderApplication.h" + +BaristaBinarySendData::BaristaBinarySendData(Label* label, Block* outputBlock, const ImageReference* imageRef) : + _label{label}, _outputBlock{outputBlock}, _imageReference{imageRef} +{ + if (!imageRef) + throw std::invalid_argument{_EXCPT("imageRef may not be null")}; +} + +void BaristaBinarySendData::send(BaristaSocket& socket) +{ + // Send all binary messages in order + for (unsigned int i = 0; i < _binaryMessages.size(); ++i) + socket.sendMessage(_binaryMessages[i], true, true, i < _binaryMessages.size() - 1); +} + +void BaristaBinarySendData::addData(QSize imageSize, bool colorImage) +{ + cv::Mat imgData; + + if (_label && _outputBlock) // Process the pipeline using imgRef as the active image and grab the output from the provided block + { + if (auto outputPort = _outputBlock->ports().selectByType(PortType::ImageOut)) + imgData = grinder()->engineController().executeLabelEx(_label, outputPort.get(), _imageReference, Engine::ExecutionMode::Execute); + else + throw std::runtime_error{QString{"Output block '%1' has no image out port"}.arg(_outputBlock->getFormattedName()).toStdString()}; + } + else // Simply load the plain image + imgData = _imageReference->loadImage(); + + if (imgData.empty()) + throw std::runtime_error{QString{"Failed to get the data for '%1'"}.arg(_imageReference->getImageFilePath()).toStdString()}; + + if (imgData.channels() != 3) + throw std::runtime_error{QString{"The data for '%1' is invalid"}.arg(_imageReference->getImageFilePath()).toStdString()}; + + // Convert the image data to float and map intensities to [0,1] + imgData.convertTo(imgData, CV_32F); + cv::normalize(imgData, imgData, 0, 1, cv::NORM_MINMAX); + + // Convert colors if necessary + if (imgData.channels() == 1 && colorImage) + cv::cvtColor(imgData, imgData, cv::COLOR_GRAY2BGR); + else if (imgData.channels() == 3 && !colorImage) + cv::cvtColor(imgData, imgData, cv::COLOR_BGR2GRAY); + + // Resize the image if necessary + if (imgData.rows != imageSize.height() || imgData.cols != imageSize.width()) + cv::resize(imgData, imgData, cv::Size{imageSize.width(), imageSize.height()}, cv::INTER_AREA); + + // Copy the image data to a new 4D matrix + cv::Mat mat{std::vector<int>{1, imgData.channels(), imgData.rows, imgData.cols}, CV_32FC1, cv::Scalar::all(0)}; + + if (imgData.channels() == 3) + { + cv::Mat imageChannels[3]; // BGR order + cv::split(imgData, imageChannels); + + for (int i = 0; i < 3; ++i) + copyImageChannel(imageChannels[i], mat, i); + } + else if (imgData.channels() == 1) + copyImageChannel(imgData, mat, 0); + + // Create a ZMQ message based on the matrix data + _binaryMessages.emplace_back(mat.ptr(), mat.total() * mat.elemSize()); +} + +void BaristaBinarySendData::copyImageChannel(const cv::Mat& imgChannel, cv::Mat& mat, int channel) const +{ + // Copy all rows individually + for (int r = 0; r < imgChannel.rows; ++r) + std::memcpy(mat.ptr(0, channel, r), imgChannel.ptr(r), imgChannel.elemSize() * imgChannel.cols); +} diff --git a/Grinder/barista/BaristaBinarySendData.h b/Grinder/barista/BaristaBinarySendData.h new file mode 100644 index 0000000000000000000000000000000000000000..f9124d147886fb7b88775fa06c24d1dd8eefb543 --- /dev/null +++ b/Grinder/barista/BaristaBinarySendData.h @@ -0,0 +1,43 @@ +/****************************************************************************** + * File: BaristaBinarySendData.h + * Date: 17.12.2018 + *****************************************************************************/ + +#ifndef BARISTABINARYSENDDATA_H +#define BARISTABINARYSENDDATA_H + +#include "BaristaBinaryData.h" + +namespace grndr +{ + class Label; + class Block; + class ImageReference; + class BaristaSocket; + + class BaristaBinarySendData : public BaristaBinaryData + { + public: + BaristaBinarySendData(Label* label, Block* outputBlock, const ImageReference* imageRef); + + public: + void send(BaristaSocket& socket); + + public: + void addData(QSize imageSize, bool colorImage); + + unsigned int getMessageCount() const { return _binaryMessages.size(); } + + private: + void copyImageChannel(const cv::Mat& imgChannel, cv::Mat& mat, int channel) const; + + private: + Label* _label{nullptr}; + Block* _outputBlock{nullptr}; + const ImageReference* _imageReference{nullptr}; + + std::vector<zmq::message_t> _binaryMessages; + }; +} + +#endif diff --git a/Grinder/task/tasks/BaristaMessage.cpp b/Grinder/barista/BaristaMessage.cpp similarity index 100% rename from Grinder/task/tasks/BaristaMessage.cpp rename to Grinder/barista/BaristaMessage.cpp diff --git a/Grinder/task/tasks/BaristaMessage.h b/Grinder/barista/BaristaMessage.h similarity index 100% rename from Grinder/task/tasks/BaristaMessage.h rename to Grinder/barista/BaristaMessage.h diff --git a/Grinder/task/tasks/BaristaProtocol.h b/Grinder/barista/BaristaProtocol.h similarity index 78% rename from Grinder/task/tasks/BaristaProtocol.h rename to Grinder/barista/BaristaProtocol.h index 97dd2378e2d8561281cf6c85630d30f485932f95..a205227b5ecd943dc753bc8cc20f1385954e2e50 100644 --- a/Grinder/task/tasks/BaristaProtocol.h +++ b/Grinder/barista/BaristaProtocol.h @@ -6,14 +6,20 @@ #ifndef BARISTAPROTOCOL_H #define BARISTAPROTOCOL_H +// General #define BARISTA_COMMAND_SETLIBRARY "setlibrarypath" +#define BARISTA_COMMAND_GETTESTARRAY "gettestarray" +#define BARISTA_COMMAND_SHUTDOWN "shutdown" +// Training #define BARISTA_COMMAND_STARTTRAINING "starttraining" #define BARISTA_COMMAND_PAUSETRAINING "pausetraining" #define BARISTA_COMMAND_RESUMETRAINING "starttraining" #define BARISTA_COMMAND_TRAININGDONE "trainingfinished" #define BARISTA_COMMAND_UPDATE "iterationupdate" -#define BARISTA_COMMAND_SHUTDOWN "shutdown" +// Inference +#define BARISTA_COMMAND_LOADNETWORK "loadnetwork" +#define BARISTA_COMMAND_INFER "infer" #endif diff --git a/Grinder/barista/BaristaSocket.cpp b/Grinder/barista/BaristaSocket.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d854dc99ca6eefcd8ca5e9a1548505ae997b745d --- /dev/null +++ b/Grinder/barista/BaristaSocket.cpp @@ -0,0 +1,71 @@ +/****************************************************************************** + * File: BaristaSocket.cpp + * Date: 18.12.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "BaristaSocket.h" + +BaristaSocket::BaristaSocket(zmq::context_t& context, int type) : socket_t(context, type) +{ + // Make sure that all pending messages are sent on close + int linger = -1; + setsockopt(ZMQ_LINGER, &linger, sizeof(linger)); +} + +void BaristaSocket::bindToPort(unsigned int port) +{ + bind(QString{"tcp://*:%1"}.arg(port).toStdString()); +} + +void BaristaSocket::subscribeToEverything() +{ + setsockopt(ZMQ_SUBSCRIBE, "", 0); +} + +void BaristaSocket::sendMessage(zmq::message_t& message, bool receiveAck, bool blocking, bool sendMore) +{ + try { +#if defined(QT_DEBUG) + QString rawData = QString::fromLatin1(static_cast<char*>(message.data()), message.size()); + qDebug() << "OUT:" << rawData; +#endif + int flags = 0; + + if (!blocking) + flags |= ZMQ_DONTWAIT; + + if (sendMore) + flags |= ZMQ_SNDMORE; + + send(message, flags); + + if (receiveAck && !sendMore) + { + // Receive an ACK message from Barista; we are not really interested in it, though, so we just ignore it + zmq::message_t ackMessage; + recvMessage(ackMessage, blocking); + } + } catch (std::exception& e) { + // Ignore ZMQ errors here + } +} + +void BaristaSocket::recvMessage(zmq::message_t& message, bool blocking) +{ + try { + int flags = 0; + + if (!blocking) + flags |= ZMQ_DONTWAIT; + + recv(&message, flags); + +#if defined(QT_DEBUG) + QString rawData = QString::fromLatin1(static_cast<char*>(message.data()), message.size()); + qDebug() << "IN:" << rawData; +#endif + } catch (std::exception& e) { + // Ignore ZMQ errors here + } +} diff --git a/Grinder/barista/BaristaSocket.h b/Grinder/barista/BaristaSocket.h new file mode 100644 index 0000000000000000000000000000000000000000..529f924dd478d99fe195b0a94bde274ad5cf8d45 --- /dev/null +++ b/Grinder/barista/BaristaSocket.h @@ -0,0 +1,28 @@ +/****************************************************************************** + * File: BaristaSocket.h + * Date: 18.12.2018 + *****************************************************************************/ + +#ifndef BARISTASOCKET_H +#define BARISTASOCKET_H + +#include <zmq.hpp> + +namespace grndr +{ + class BaristaSocket : public zmq::socket_t + { + public: + BaristaSocket(zmq::context_t& context, int type); + + public: + void bindToPort(unsigned int port); + void subscribeToEverything(); + + public: + void sendMessage(zmq::message_t& message, bool receiveAck = true, bool blocking = true, bool sendMore = false); + void recvMessage(zmq::message_t& message, bool blocking = true); + }; +} + +#endif diff --git a/Grinder/common/serialization/DeserializationContext.h b/Grinder/common/serialization/DeserializationContext.h index 43d1bad6b2b3fe0db979802c8c1808b0ed972050..b02f03afe8eaabae40f8f367177e7b27e4d737b1 100644 --- a/Grinder/common/serialization/DeserializationContext.h +++ b/Grinder/common/serialization/DeserializationContext.h @@ -13,6 +13,7 @@ namespace grndr { + class Label; class Block; class Port; class ImageReference; @@ -32,6 +33,9 @@ namespace grndr void endGroup() { _settingsStack.pop(); } public: + void addLabel(int index, Label* label) { _labels[index] = label; } + Label* getLabel(int index) const { return getObject(index, _labels); } + void addBlock(int index, Block* block) { _blocks[index] = block; } Block* getBlock(int index) const { return getObject(index, _blocks); } @@ -52,6 +56,7 @@ namespace grndr SettingsContainer _settings; std::stack<SettingsContainer*> _settingsStack; + std::map<int, Label*> _labels; std::map<int, Block*> _blocks; std::map<int, Port*> _ports; std::map<int, ImageReference*> _imageReferences; diff --git a/Grinder/common/serialization/JsonSettingsCodec.cpp b/Grinder/common/serialization/JsonSettingsCodec.cpp index 4ee5c7b40970d40690b77938ef5768447f5ca614..9ba79a1312eb4e480fc7c568f6b9452d0674e8bc 100644 --- a/Grinder/common/serialization/JsonSettingsCodec.cpp +++ b/Grinder/common/serialization/JsonSettingsCodec.cpp @@ -92,7 +92,7 @@ void JsonSettingsCodec::writeJsonValues(const SettingsContainer& settings, QText for (int i = 0; i < keys.size(); ++i) { auto valueName = keys[i]; - writeJsonValue(valueName, settings.values()[valueName], stream, level, i + 1 >= keys.size()); + writeJsonValue(!settings.isValueArray() ? valueName : "", settings.values()[valueName], stream, level, i + 1 >= keys.size()); } } @@ -116,7 +116,10 @@ void JsonSettingsCodec::writeJsonValue(QString valueName, const QVariant& value, break; } - writeJsonLine(QString{"\"%1\": %2%3"}.arg(valueName).arg(formattedValue).arg(!lastEntry ? "," : ""), stream, level); + if (!valueName.isEmpty()) + writeJsonLine(QString{"\"%1\": %2%3"}.arg(valueName).arg(formattedValue).arg(!lastEntry ? "," : ""), stream, level); + else + writeJsonLine(QString{"%1%2"}.arg(formattedValue).arg(!lastEntry ? "," : ""), stream, level); } void JsonSettingsCodec::writeJsonLine(const QString& text, QTextStream& stream, int level) const @@ -149,7 +152,7 @@ std::vector<JsonSettingsCodec::Token> JsonSettingsCodec::readJsonDocument(QTextS TokenizerEntry{R"(\])", TokenType::ArrayEnd}, TokenizerEntry{R"(:)", TokenType::Colon}, TokenizerEntry{R"(,)", TokenType::Comma}, - TokenizerEntry{R"([^,{}]+)", TokenType::Value}, + TokenizerEntry{R"([^,{}\[\]]+)", TokenType::Value}, }; // Tokenize the entire document, line by line @@ -259,28 +262,57 @@ void JsonSettingsCodec::parseJsonObjectItem(SettingsContainer& settings, const s void JsonSettingsCodec::parseJsonArray(SettingsContainer& settings, const std::vector<Token>& tokens, ParseContext& parseContext) const { - // Array ::= '[' Object {',' Object} ']' | '[' ']' + // Array ::= '[' Object {',' Object} ']' | '[' Value|Array {',' Value|Array} ']' | '[' ']' rdpExpect(TokenType::ArrayBegin, tokens, parseContext, [&settings, &parseContext](QString) { parseContext.pushChildSettings(settings, true); }); if (!rdpAccept(TokenType::ArrayEnd, tokens, parseContext, nullptr, true)) // Check for an empty array { - do { - // Handle the array elements which are encapsulated as individual objects - SettingsContainer arrayElemContainer; - parseContext.settingsStack.push(&arrayElemContainer); - parseContext.identifier = "$_array_elem_$"; - parseJsonObject(settings, tokens, parseContext); - parseContext.settingsStack.pop(); - - // Extract the elements from their encapsulating objects - for (auto arrayElem : arrayElemContainer.children("$_array_elem_$")) - { - if (arrayElem->children().size() == 1) // There must be exactly one child object in the encapsulating object - *parseContext.settingsStack.top() << std::move(*arrayElem->children().at(0)); - else - throw SerializationException{"Invalid container object found"}; - } - } while (rdpAccept(TokenType::Comma, tokens, parseContext)); + // Peek at the next token to decide what kind of array elements to expect + if (rdpAccept(TokenType::ObjectBegin, tokens, parseContext, nullptr, true)) + { + do { + // Handle the array elements which are encapsulated as individual objects + SettingsContainer arrayElemContainer; + parseContext.settingsStack.push(&arrayElemContainer); + parseContext.identifier = "$_array_elem_$"; + parseJsonObject(settings, tokens, parseContext); + parseContext.settingsStack.pop(); + + // Extract the elements from their encapsulating objects + for (auto arrayElem : arrayElemContainer.children("$_array_elem_$")) + { + if (arrayElem->children().size() == 1) // There must be exactly one child object in the encapsulating object + *parseContext.settingsStack.top() << std::move(*arrayElem->children().at(0)); + else + throw SerializationException{"Invalid container object found"}; + } + } while (rdpAccept(TokenType::Comma, tokens, parseContext)); + } + else + { + QStringList values; + int childIndex = 0; + + do { + // Check the next token type and handle them accordingly + if (rdpAccept(TokenType::ArrayBegin, tokens, parseContext, nullptr, true)) + { + parseContext.identifier = QString{"child%1"}.arg(childIndex++); + parseJsonArray(settings, tokens, parseContext); + } + else if (rdpAccept(TokenType::String, tokens, parseContext, nullptr, true)) + rdpExpect(TokenType::String, tokens, parseContext, [&values](QString value) { removeEnclosingQuotes(value); values << value; }); + else if (rdpAccept(TokenType::Value, tokens, parseContext, nullptr, true)) + rdpExpect(TokenType::Value, tokens, parseContext, [&values](QString value) { values << value; }); + + // Add all collected values + for (int i = 0; i < values.size(); ++i) + { + parseContext.identifier = QString{"value%1"}.arg(i); + parseContext.setSettingsValue(values[i]); + } + } while (rdpAccept(TokenType::Comma, tokens, parseContext)); + } } rdpExpect(TokenType::ArrayEnd, tokens, parseContext, [&parseContext](QString) { parseContext.popChildSettings(); }); diff --git a/Grinder/common/serialization/SerializationContext.h b/Grinder/common/serialization/SerializationContext.h index 02303aeb6f784b3bf21d343461d5862d172a0dd6..a15979862e2c01e724843590c280667754496ac7 100644 --- a/Grinder/common/serialization/SerializationContext.h +++ b/Grinder/common/serialization/SerializationContext.h @@ -14,6 +14,7 @@ namespace grndr { class Project; + class Label; class Block; class Port; class Connection; @@ -33,6 +34,9 @@ namespace grndr void endGroup(); public: + int addLabel(const Label* label) { return addObject(label, _labels); } + int getLabelIndex(const Label* label) const { return getObjectIndex(label, _labels); } + int addBlock(const Block* block) { return addObject(block, _blocks); } int getBlockIndex(const Block* block) const { return getObjectIndex(block, _blocks); } @@ -56,6 +60,7 @@ namespace grndr SettingsContainer _settings; std::stack<SettingsContainer> _settingsStack; + std::vector<const Label*> _labels; std::vector<const Block*> _blocks; std::vector<const Port*> _ports; std::vector<const Connection*> _connections; diff --git a/Grinder/common/serialization/SettingsContainer.cpp b/Grinder/common/serialization/SettingsContainer.cpp index 7d8bc3f799a3c07354c6fc8c775424bae17f0da4..e41c9948e3387f20df40a90f4c893616d20f9e56 100644 --- a/Grinder/common/serialization/SettingsContainer.cpp +++ b/Grinder/common/serialization/SettingsContainer.cpp @@ -6,8 +6,8 @@ #include "Grinder.h" #include "SettingsContainer.h" -SettingsContainer::SettingsContainer(QString name, bool isArray) : - _name{name}, _isArray{isArray} +SettingsContainer::SettingsContainer(QString name, bool isArray, bool isValueArray) : + _name{name}, _isArray{isArray}, _isValueArray{isArray && isValueArray} { } @@ -48,6 +48,9 @@ SettingsContainer& SettingsContainer::operator <<(const SettingsContainer& child SettingsContainer* SettingsContainer::createChildEx(QString name, bool unique, bool isArray) { + if (name.isEmpty()) + name = QString{"child%1"}.arg(_childContainers.size()); + if (unique) { if (auto childContainer = child(name)) diff --git a/Grinder/common/serialization/SettingsContainer.h b/Grinder/common/serialization/SettingsContainer.h index f518a75f0f7b7499dcd1cfc89436afc37f590f57..8ac49a6f267f311ac77dbc41650dfc54960985c1 100644 --- a/Grinder/common/serialization/SettingsContainer.h +++ b/Grinder/common/serialization/SettingsContainer.h @@ -15,7 +15,7 @@ namespace grndr class SettingsContainer { public: - SettingsContainer(QString name = "", bool isArray = false); + SettingsContainer(QString name = "", bool isArray = false, bool isValueArray = false); SettingsContainer(const SettingsContainer& container) = default; SettingsContainer(SettingsContainer&& container) = default; @@ -25,6 +25,7 @@ namespace grndr public: QString getName() const { return _name; } bool isArray() const { return _isArray; } + bool isValueArray() const { return _isValueArray; } bool isEmpty() const { return _childContainers.empty() && _values.isEmpty(); } @@ -42,6 +43,8 @@ namespace grndr SettingsContainer* child(QString name) { return _child<SettingsContainer*>(name); } const SettingsContainer* child(QString name) const { return _child<const SettingsContainer*>(name); } + SettingsContainer* child(int index) { return _child<SettingsContainer*>(QString{"child%1"}.arg(index)); } + const SettingsContainer* child(int index) const { return _child<const SettingsContainer*>(QString{"child%1"}.arg(index)); } const std::vector<SettingsContainer*> children() { return children(""); } const std::vector<const SettingsContainer*> children() const { return children(""); } const std::vector<SettingsContainer*> children(QString name) { return _children<SettingsContainer*>(name); } @@ -53,9 +56,13 @@ namespace grndr public: QVariant& value(QString name, QVariant defaultValue = QVariant{}); QVariant value(QString name, QVariant defaultValue = QVariant{}) const { return _values.value(name, defaultValue); } + QVariant& value(int index, QVariant defaultValue = QVariant{}) { return value(QString{"value%1"}.arg(index), defaultValue); } + QVariant value(int index, QVariant defaultValue = QVariant{}) const { return _values.value(QString{"value%1"}.arg(index), defaultValue); } QVariant& operator()(QString name, QVariant defaultValue = QVariant{}) { return value(name, defaultValue); } QVariant operator()(QString name, QVariant defaultValue = QVariant{}) const { return value(name, defaultValue); } + QVariant& operator()(int index, QVariant defaultValue = QVariant{}) { return value(index, defaultValue); } + QVariant operator()(int index, QVariant defaultValue = QVariant{}) const { return value(index, defaultValue); } const QVariantMap& values() const { return _values; } @@ -80,6 +87,7 @@ namespace grndr private: QString _name{""}; bool _isArray{false}; + bool _isValueArray{false}; std::vector<std::shared_ptr<SettingsContainer>> _childContainers; QVariantMap _values; diff --git a/Grinder/main.cpp b/Grinder/main.cpp index 8fddd05bb8e1f29a04afa5d98266f4ec80829562..997a12d973ae97993ffff2e341fbb8b72c512e7b 100644 --- a/Grinder/main.cpp +++ b/Grinder/main.cpp @@ -38,4 +38,6 @@ int main(int argc, char *argv[]) ShowExceptionMessage(""); throw; } + + return 0; } diff --git a/Grinder/project/ImageReference.cpp b/Grinder/project/ImageReference.cpp index b08b3a295ccf3e8141c1710f641ef038a51e9d99..e6b98f478ed718ad56f492747eedd483fbccd010 100644 --- a/Grinder/project/ImageReference.cpp +++ b/Grinder/project/ImageReference.cpp @@ -84,6 +84,7 @@ void ImageReference::serialize(SerializationContext& ctx) const void ImageReference::deserialize(DeserializationContext& ctx) { + // Deserialize values ctx.addImageReference(ctx.settings()(Serialization_Value_Index).toInt(), this); } diff --git a/Grinder/project/Label.cpp b/Grinder/project/Label.cpp index 948f7d50c9f0e7dcaa8ae65c3ea448e2aae95a62..717cfea516fe70c378ccee4a547425882b23b8f5 100644 --- a/Grinder/project/Label.cpp +++ b/Grinder/project/Label.cpp @@ -11,6 +11,7 @@ const char* Label::Serialization_Group_Layout = "Layout"; const char* Label::Serialization_Group_ImageBuildPool = "ImageBuildPool"; const char* Label::Serialization_Value_Name = "Name"; +const char* Label::Serialization_Value_Index = "Index"; Label::Label(Project* project, const std::shared_ptr<Pipeline>& pipeline) : ProjectItem(project), _pipeline{pipeline}, _imageBuildPool{this} @@ -23,6 +24,7 @@ void Label::serialize(SerializationContext& ctx) const { // Serialize values ctx.settings()(Serialization_Value_Name) = getName(); + ctx.settings()(Serialization_Value_Index) = ctx.addLabel(this); // Serialize the pipeline ctx.beginGroup(Serialization_Group_Pipeline); @@ -42,6 +44,9 @@ void Label::serialize(SerializationContext& ctx) const void Label::deserialize(DeserializationContext& ctx) { + // Deserialize values + ctx.addLabel(ctx.settings()(Serialization_Value_Index).toInt(), this); + // Deserialize the pipeline if (ctx.beginGroup(Serialization_Group_Pipeline)) { diff --git a/Grinder/project/Label.h b/Grinder/project/Label.h index 437f2809ee7bc855b327a91f8ddc8a0b87ed4196..601afdcc70e9b8b0c0cd98d53fd40bd7eb80b577 100644 --- a/Grinder/project/Label.h +++ b/Grinder/project/Label.h @@ -27,6 +27,7 @@ namespace grndr static const char* Serialization_Group_ImageBuildPool; static const char* Serialization_Value_Name; + static const char* Serialization_Value_Index; public: Label(Project* project, const std::shared_ptr<Pipeline>& pipeline); diff --git a/Grinder/project/Project.cpp b/Grinder/project/Project.cpp index 71761c3032d1c8e8dbce72da3b423333b0513946..ec42a75724606949e0bcb707a4721e358f7f6313 100644 --- a/Grinder/project/Project.cpp +++ b/Grinder/project/Project.cpp @@ -152,16 +152,6 @@ void Project::serialize(SerializationContext& ctx) const // Serialize values ctx.settings()(Serialization_Value_Name) = _name; - // Serialize the task pool - { - LongOperation opSerTaskPool{"Saving task pool", 1, false, true}; - opSerTaskPool.setStatusMessage("Task pool"); - - ctx.beginGroup(TaskPool::Serialization_Group); - _taskPool.serialize(ctx); - ctx.endGroup(); - } - // Serialize all image references (has to be done before the label serialization, since labels reference them) { LongOperation opSerImages{"Saving images", 1, false, true}; @@ -191,6 +181,16 @@ void Project::serialize(SerializationContext& ctx) const grinder()->imageEditorManager().serialize(ctx); ctx.endGroup(); } + + // Serialize the task pool + { + LongOperation opSerTaskPool{"Saving task pool", 1, false, true}; + opSerTaskPool.setStatusMessage("Task pool"); + + ctx.beginGroup(TaskPool::Serialization_Group); + _taskPool.serialize(ctx); + ctx.endGroup(); + } } void Project::deserialize(DeserializationContext& ctx) @@ -200,16 +200,6 @@ void Project::deserialize(DeserializationContext& ctx) // Deserialize values _name = ctx.settings()(Serialization_Value_Name).toString(); - // Deserialize the task pool - if (ctx.beginGroup(TaskPool::Serialization_Group)) - { - LongOperation opDeserTaskPool{"Loading task pool", 1, false, true}; - opDeserTaskPool.setStatusMessage("Task pool"); - - _taskPool.deserialize(ctx); - ctx.endGroup(); - } - // Deserialize all image references (has to be done before the label deserialization, since labels reference them) if (ctx.beginGroup(ImageReferenceVector::Serialization_Group)) { @@ -256,4 +246,14 @@ void Project::deserialize(DeserializationContext& ctx) grinder()->imageEditorManager().deserialize(ctx); ctx.endGroup(); } + + // Deserialize the task pool + if (ctx.beginGroup(TaskPool::Serialization_Group)) + { + LongOperation opDeserTaskPool{"Loading task pool", 1, false, true}; + opDeserTaskPool.setStatusMessage("Task pool"); + + _taskPool.deserialize(ctx); + ctx.endGroup(); + } } diff --git a/Grinder/task/TaskCatalog.cpp b/Grinder/task/TaskCatalog.cpp index 18a51aac3a6e3075b7a1d678afad9917075194c3..cb76a86de221c40927aa1f5ecc7fe76e46a24bb1 100644 --- a/Grinder/task/TaskCatalog.cpp +++ b/Grinder/task/TaskCatalog.cpp @@ -9,6 +9,7 @@ #include "tasks/GenericTask.h" #include "tasks/BaristaTrainingTask.h" +#include "tasks/BaristaInferenceTask.h" #define REGISTER_TASK_TYPE(cls) registerTaskType(cls::type_value, [](TaskPool* taskPool, QString name) { return std::make_unique<cls>(taskPool, name); }) @@ -58,4 +59,5 @@ void TaskCatalog::registerStandardTasks() { REGISTER_TASK_TYPE(GenericTask); REGISTER_TASK_TYPE(BaristaTrainingTask); + REGISTER_TASK_TYPE(BaristaInferenceTask); } diff --git a/Grinder/task/TaskType.cpp b/Grinder/task/TaskType.cpp index 3ac136c22f31953a26536dca2d9a3c3f425de2f2..01e6bcfd9d70bb622365ef2d72a738df406f404e 100644 --- a/Grinder/task/TaskType.cpp +++ b/Grinder/task/TaskType.cpp @@ -11,3 +11,4 @@ const char* TaskType::Undefined = ""; const char* TaskType::Generic = "Generic"; const char* TaskType::BaristaTraining = "BaristaTraining"; +const char* TaskType::BaristaInference = "BaristaInference"; diff --git a/Grinder/task/TaskType.h b/Grinder/task/TaskType.h index ef4b5d91249caf96bbdbf79cf4ee9a219ea7c3f2..928b504dcf751a38cd62c5e420bab8d423a3bcb7 100644 --- a/Grinder/task/TaskType.h +++ b/Grinder/task/TaskType.h @@ -18,6 +18,7 @@ namespace grndr static const char* Generic; static const char* BaristaTraining; + static const char* BaristaInference; public: using QString::QString; diff --git a/Grinder/task/tasks/BaristaInferenceTask.cpp b/Grinder/task/tasks/BaristaInferenceTask.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc0c9a8d62d1e7a0358afc91e17fb04576838c5a --- /dev/null +++ b/Grinder/task/tasks/BaristaInferenceTask.cpp @@ -0,0 +1,415 @@ +/****************************************************************************** + * File: BaristaInferenceTask.cpp + * Date: 07.12.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "BaristaInferenceTask.h" +#include "barista/BaristaMessage.h" +#include "barista/BaristaProtocol.h" +#include "barista/BaristaBinaryRecvData.h" +#include "barista/BaristaBinarySendData.h" +#include "core/GrinderApplication.h" +#include "task/TaskExceptions.h" +#include "ui/task/tasks/BaristaInferenceTaskWidget.h" + +const TaskType BaristaInferenceTask::type_value = TaskType::BaristaInference; + +const char* BaristaInferenceTask::Serialization_Value_NetworkPath = "NetworkPath"; +const char* BaristaInferenceTask::Serialization_Value_ModelPath = "ModelPath"; +const char* BaristaInferenceTask::Serialization_Value_Label = "Label"; +const char* BaristaInferenceTask::Serialization_Value_OutputBlock = "OutputBlock"; +const char* BaristaInferenceTask::Serialization_Value_ImageReferences = "ImageReferences"; + +BaristaInferenceTask::BaristaInferenceTask(TaskPool* taskPool, QString name) : BaristaTask(taskPool, type_value, name) +{ + // Listen for removed image references to remove them here as well + connect(&grinder()->project(), &Project::imageReferenceRemoved, this, &BaristaInferenceTask::imageReferenceRemoved); +} + +ConfigureTaskWidgetBase* BaristaInferenceTask::createEditor(bool newTask, QWidget* parent) +{ + return new BaristaInferenceTaskWidget{this, newTask, parent}; +} + +void BaristaInferenceTask::setLabel(Label* label) +{ + if (_label) + disconnect(&grinder()->project(), nullptr, this, nullptr); + + _label = label; + + // Listen to removed labels to reset the assigned label if necessary + if (_label) + connect(&grinder()->project(), &Project::labelRemoved, this, &BaristaInferenceTask::labelRemoved); +} + +void BaristaInferenceTask::setOutputBlock(Block* block) +{ + if (_outputBlock) + disconnect(_outputBlock->pipeline(), nullptr, this, nullptr); + + _outputBlock = block; + + // Listen to removed blocks to reset the assigned output block if necessary + if (_outputBlock) + connect(_outputBlock->pipeline(), &Pipeline::blockRemoved, this, &BaristaInferenceTask::blockRemoved); +} + +void BaristaInferenceTask::setImageReferences(const std::vector<const ImageReference*>& imageRefs) +{ + _imageReferences = imageRefs; +} + +void BaristaInferenceTask::serialize(SerializationContext& ctx) const +{ + BaristaTask::serialize(ctx); + + // Serialize values + ctx.settings()(Serialization_Value_NetworkPath) = _networkPath; + ctx.settings()(Serialization_Value_ModelPath) = _modelPath; + ctx.settings()(Serialization_Value_Label) = ctx.getLabelIndex(_label); + ctx.settings()(Serialization_Value_OutputBlock) = ctx.getBlockIndex(_outputBlock); + + // Serialize image references + QStringList imageReferences; + + for (auto imageRef : _imageReferences) + imageReferences << QString{"%1"}.arg(ctx.getImageReferenceIndex(imageRef)); + + ctx.settings()(Serialization_Value_ImageReferences) = imageReferences.join(","); +} + +void BaristaInferenceTask::deserialize(DeserializationContext& ctx) +{ + BaristaTask::deserialize(ctx); + + // Deserialize values + _networkPath = ctx.settings()(Serialization_Value_NetworkPath).toString(); + _modelPath = ctx.settings()(Serialization_Value_ModelPath).toString(); + setLabel(ctx.getLabel(ctx.settings()(Serialization_Value_Label, -1).toInt())); + setOutputBlock(ctx.getBlock(ctx.settings()(Serialization_Value_OutputBlock, -1).toInt())); + + // Deserialize image references + for (auto imageRef : ctx.settings()(Serialization_Value_ImageReferences).toString().split(",")) + { + auto imageRefIndex = imageRef.toInt(); + + if (imageRefIndex != -1) + { + if (auto imageRef = ctx.getImageReference(imageRefIndex)) + _imageReferences.push_back(imageRef); + } + } +} + +void BaristaInferenceTask::baristaReady() +{ + // Begin by requesting a well-defined test data array + changeTaskState(InferenceTaskState::GetTestData, "Retrieving test data..."); + sendGetTestDataMessage(); +} + +bool BaristaInferenceTask::handleReplyMessage(const SettingsContainer& messageData) +{ + if (BaristaTask::handleReplyMessage(messageData)) + return true; + + return false; +} + +bool BaristaInferenceTask::handleSubscriberMessage(const SettingsContainer& messageData) +{ + if (BaristaTask::handleSubscriberMessage(messageData)) + return true; + + switch (_taskState) + { + case InferenceTaskState::GetTestData: + handleGetTestData(messageData); + return true; + + case InferenceTaskState::LoadNetwork: + handleLoadNetwork(messageData); + return true; + + case InferenceTaskState::InferImages: + handleInferImages(messageData); + return true; + } + + return false; +} + +void BaristaInferenceTask::handleGetTestData(const SettingsContainer& messageData) +{ + BaristaMessage message{messageData}; + + if (message.isMessage(BARISTA_COMMAND_GETTESTARRAY)) + { + if (checkMessageStatus(message)) + { + bool testDataValid = false; + + if (auto data = message.payload().child("data")) + { + if (auto testArray = data->child("test_array")) + { + BaristaBinaryRecvData binaryData{*testArray, *_subscriberSocket.get()}; + + if (auto rawData = binaryData.data("raw_data")) + { + if (rawData->size() == 2 * 3 * 4 * 4) // 2 x 3 x 4 dimensions, 4 bytes per element + { + if (binaryData.checkDataType("float32")) + { + if (binaryData.checkDimensions({2, 3, 4})) + { + auto matrix = binaryData.matrix("raw_data"); + + if (matrix.at<float>(cv::Vec3i{0, 1, 0}) == 4.0f && matrix.at<float>(cv::Vec3i{0, 1, 2}) == 6.0f && matrix.at<float>(cv::Vec3i{1, 1, 2}) == 18.0f) + testDataValid = true; + else + addMessageLog("\tThe received data does not match"); + } + else + addMessageLog("\tThe shape dimensions do not match"); + } + else + addMessageLog("\tUnexpected data format"); + } + else + addMessageLog("\tInvalid data size"); + } + else + addMessageLog("\tNo raw data at index 0"); + } + } + + if (testDataValid) + { + addMessageLog("\tTest data received and verified"); + + // Load the trained network + changeTaskState(InferenceTaskState::LoadNetwork, QString{"Loading trained network '%1'..."}.arg(_networkPath)); + sendLoadNetworkMessage(); + } + else + reportBaristaError("\tFailed to receive and verify the test data", &message); + } + else + reportBaristaError("\tFailed to start the training", &message); + } +} + +void BaristaInferenceTask::handleLoadNetwork(const SettingsContainer& messageData) +{ + BaristaMessage message{messageData}; + + if (message.isMessage(BARISTA_COMMAND_LOADNETWORK)) + { + if (checkMessageStatus(message)) + { + // Extract information about in- and outputs from the network + if (auto data = message.payload().child("data")) + { + auto extractInfo = [&data](std::vector<NetworkData>& nwDataVec, QString name) { + if (auto inputs = data->child(name)) + { + for (auto input : inputs->children()) + { + NetworkData nwData; + nwData.name = input->value(0).toString(); // We assume that the container contains exactly one value (the name) + + if (auto dims = input->child(0)) // We assume that the container contains exactly one child (the dimensions) + { + for (auto dim : dims->values()) + nwData.dimensions.push_back(dim.toInt()); + } + + nwDataVec.push_back(std::move(nwData)); + } + } + }; + + extractInfo(_networkInputs, "inputs"); + extractInfo(_networkOutputs, "outputs"); + } + + addMessageLog("\tNetwork loaded"); + addMessageLog("", false); + + changeTaskState(InferenceTaskState::InferImages, "Starting inference..."); + sendInferImageMessage(0); + } + else + reportBaristaError("\tFailed to load the trained network", &message); + } +} + +void BaristaInferenceTask::handleInferImages(const SettingsContainer& messageData) +{ + BaristaMessage message{messageData}; + + if (message.isMessage(BARISTA_COMMAND_INFER)) + { + if (checkMessageStatus(message)) + { + QString result; + + // Get all output results + for (auto output : _networkOutputs) + { + if (auto data = message.payload().child("data")) + { + if (auto outputData = data->child(output.name)) + { + BaristaBinaryRecvData binaryData{*outputData, *_subscriberSocket.get()}; + result = formatInferenceResult(binaryData); + } + } + } + + addMessageLog(QString{"\t\tSucceeded [%1]"}.arg(result), false); + + // Send the next infer message; if the index goes out of bounds, the task will be finished + sendInferImageMessage(_currentInferImageIndex + 1); + } + else + reportBaristaError("\t\tFailed", &message); + } +} + +void BaristaInferenceTask::sendGetTestDataMessage() +{ + BaristaMessage message{BARISTA_COMMAND_GETTESTARRAY}; + sendMessage(message, _requestSocket); +} + +void BaristaInferenceTask::sendLoadNetworkMessage() +{ + BaristaMessage message{BARISTA_COMMAND_LOADNETWORK}; + SettingsContainer data{"data"}; + data("network") = _networkPath; + data("model") = _modelPath; + message.payload() << std::move(data); + sendMessage(message, _requestSocket); +} + +void BaristaInferenceTask::sendInferImageMessage(unsigned int inferImageIndex) +{ + if (inferImageIndex < _imageReferences.size()) // Any images left to infer? + { + auto imageRef = _imageReferences[inferImageIndex]; + addMessageLog(QString{"\tPerforming inference on '%1'..."}.arg(imageRef->getImageFileName())); + + BaristaMessage message{BARISTA_COMMAND_INFER}; + BaristaBinarySendData binaryData{_label, _outputBlock, imageRef}; + SettingsContainer data{"data"}; + + for (unsigned int inputIndex = 0; inputIndex < _networkInputs.size(); ++inputIndex) + { + auto input = _networkInputs[inputIndex]; + + // Add settings for the current network input + SettingsContainer inputData{input.name}; + inputData.value("raw_data") = inputIndex; + inputData.value("dtype") = BaristaBinaryData::getBaristaDataType(CV_32FC1); + + SettingsContainer shape{"shape", true, true}; + + // We have to make the following assumptions about the shape of the data to send: + // 1. The shape has exactly 4 dimensions + // 2. The first dimension indicates the number of images + // 3. The second one indicates the number of channels + // 4. The third and fourth one represent the height and width of the input data + + if (input.dimensions.size() == 4) + { + int images = 1; // Force the first dimension to always be 1, as we are sending exactly one image + int channels = input.dimensions[1]; + int height = input.dimensions[2]; + int width = input.dimensions[3]; + + if (channels != 1 && channels != 3) // We only support grayscale or color images + channels = 1; + + shape.value(0) = images; + shape.value(1) = channels; + shape.value(2) = height; + shape.value(3) = width; + + inputData << std::move(shape); + data << std::move(inputData); + + // Add a new entry in the desired shape to the binary message + binaryData.addData(QSize{width, height}, channels == 3); + } + else + reportBaristaError(QString{"Invalid input data dimensions for '%1'"}.arg(input.name)); + } + + // Send the infer message + message.payload() << std::move(data); + sendMessage(message, _requestSocket, true, true, binaryData.getMessageCount() > 0); + + // And send the binary data for each input + binaryData.send(*_requestSocket.get()); + + _currentInferImageIndex = inferImageIndex; + } + else // All images have been processed + { + addMessageLog("All images have been processed"); + + // The inference has finished, so break the Barista connection and finish the task + shutdownBaristaConnection(); + finishTask(true); + } +} + +QString BaristaInferenceTask::formatInferenceResult(const BaristaBinaryRecvData& result) const +{ + QStringList resultList; + + if (result.checkDataType(CV_32FC1)) // Result must be floats + { + auto dims = result.getDimensions(); + + if (dims.size() == 2 && dims[0] == 1) // Dimensions of the result must be 1x? + { + for (auto dataName : result.getDataNames()) + { + QStringList currentResult; + auto matrix = result.matrix(dataName); + + for (int i = 0; i < dims[1]; ++i) + currentResult << QString{"Class %1: %2"}.arg(i).arg(matrix.at<float>(0, i), 0, 'g', 3); + + resultList << currentResult.join(", "); + } + } + } + + if (resultList.isEmpty()) + resultList << "Invalid result"; + + return resultList.join("; "); +} + +void BaristaInferenceTask::labelRemoved(const std::shared_ptr<Label>& label) +{ + if (label.get() == _label) + setLabel(nullptr); +} + +void BaristaInferenceTask::blockRemoved(const std::shared_ptr<Block>& block) +{ + if (block.get() == _outputBlock) + setOutputBlock(nullptr); +} + +void BaristaInferenceTask::imageReferenceRemoved(const std::shared_ptr<ImageReference>& imageRef) +{ + _imageReferences.erase(std::remove_if(_imageReferences.begin(), _imageReferences.end(), [&imageRef](auto imgRef) { return imgRef == imageRef.get(); }), _imageReferences.end()); +} diff --git a/Grinder/task/tasks/BaristaInferenceTask.h b/Grinder/task/tasks/BaristaInferenceTask.h new file mode 100644 index 0000000000000000000000000000000000000000..c39704c4a11a11a057c00af4ce5541db2c561ce1 --- /dev/null +++ b/Grinder/task/tasks/BaristaInferenceTask.h @@ -0,0 +1,105 @@ +/****************************************************************************** + * File: BaristaInferenceTask.h + * Date: 07.12.2018 + *****************************************************************************/ + +#ifndef BARISTAINFERENCETASK_H +#define BARISTAINFERENCETASK_H + +#include "BaristaTask.h" + +namespace grndr +{ + class Label; + class BaristaBinaryRecvData; + + class BaristaInferenceTask : public BaristaTask + { + Q_OBJECT + + public: + static const TaskType type_value; + + static const char* Serialization_Value_NetworkPath; + static const char* Serialization_Value_ModelPath; + static const char* Serialization_Value_Label; + static const char* Serialization_Value_OutputBlock; + static const char* Serialization_Value_ImageReferences; + + public: + BaristaInferenceTask(TaskPool* taskPool, QString name = ""); + + public: + virtual ConfigureTaskWidgetBase* createEditor(bool newTask, QWidget* parent) override; + + public: + QString getNetworkPath() const { return _networkPath; } + void setNetworkPath(QString path) { _networkPath = path; } + QString getModelPath() const { return _modelPath; } + void setModelPath(QString path) { _modelPath = path; } + + Label* getLabel() const { return _label; } + void setLabel(Label* label); + Block* getOutputBlock() const { return _outputBlock; } + void setOutputBlock(Block* block); + std::vector<const ImageReference*> getImageReferences() const { return _imageReferences; } + void setImageReferences(const std::vector<const ImageReference*>& imageRefs); + + public: + virtual void serialize(SerializationContext& ctx) const override; + virtual void deserialize(DeserializationContext& ctx) override; + + protected: + virtual void baristaReady() override; + + virtual bool handleReplyMessage(const SettingsContainer& messageData) override; + virtual bool handleSubscriberMessage(const SettingsContainer& messageData) override; + + private: + void handleGetTestData(const SettingsContainer& messageData); + void handleLoadNetwork(const SettingsContainer& messageData); + void handleInferImages(const SettingsContainer& messageData); + + void sendGetTestDataMessage(); + void sendLoadNetworkMessage(); + void sendInferImageMessage(unsigned int inferImageIndex); + + private: + QString formatInferenceResult(const grndr::BaristaBinaryRecvData& result) const; + + private slots: + void labelRemoved(const std::shared_ptr<Label>& label); + void blockRemoved(const std::shared_ptr<Block>& block); + void imageReferenceRemoved(const std::shared_ptr<ImageReference>& imageRef); + + private: + enum InferenceTaskState + { + GetTestData = TypeSpecificBase, + LoadNetwork, + InferImages, + }; + + private: + QString _networkPath{""}; + QString _modelPath{""}; + + Label* _label{nullptr}; + Block* _outputBlock{nullptr}; + std::vector<const ImageReference*> _imageReferences; + + private: + struct NetworkData + { + QString name{""}; + std::vector<int> dimensions; + }; + + std::vector<NetworkData> _networkInputs; + std::vector<NetworkData> _networkOutputs; + + unsigned int _currentInferImageIndex{0}; + }; +} + +#endif diff --git a/Grinder/task/tasks/BaristaTask.cpp b/Grinder/task/tasks/BaristaTask.cpp index 7b4a34cb9caad1cc1bba240a46e0e5210f458d2d..19b528769af748facba6d5388eb9544e47330a7e 100644 --- a/Grinder/task/tasks/BaristaTask.cpp +++ b/Grinder/task/tasks/BaristaTask.cpp @@ -5,9 +5,9 @@ #include "Grinder.h" #include "BaristaTask.h" -#include "BaristaMessage.h" -#include "BaristaProtocol.h" #include "task/TaskExceptions.h" +#include "barista/BaristaMessage.h" +#include "barista/BaristaProtocol.h" #include "common/serialization/JsonSettingsCodec.h" #include "common/serialization/SerializationExceptions.h" @@ -56,25 +56,15 @@ void BaristaTask::initiateBaristaConnection() try { _context = std::make_unique<zmq::context_t>(1); - auto createSocket = [this](int type) { - auto socket = std::make_unique<zmq::socket_t>(*_context, type); - - // Make sure that all pending messages are sent on close - int linger = -1; - socket->setsockopt(ZMQ_LINGER, &linger, sizeof(linger)); - - return socket; - }; - - _replySocket = createSocket(ZMQ_REP); - _subscriberSocket = createSocket(ZMQ_SUB); - _requestSocket = createSocket(ZMQ_REQ); + _replySocket = std::make_unique<BaristaSocket>(*_context, ZMQ_REP); + _subscriberSocket = std::make_unique<BaristaSocket>(*_context, ZMQ_SUB); + _requestSocket = std::make_unique<BaristaSocket>(*_context, ZMQ_REQ); // The subscriber should just receive everything - _subscriberSocket->setsockopt(ZMQ_SUBSCRIBE, "", 0); + _subscriberSocket->subscribeToEverything(); // Bind the reply port - _replySocket->bind(QString{"tcp://*:%1"}.arg(_baristaPort).toStdString()); + _replySocket->bindToPort(_baristaPort); } catch (std::exception& e) { throw TaskException{this, _EXCPT(QString{"Unable to create the ZMQ objects (%1)"}.arg(e.what()))}; } @@ -133,7 +123,7 @@ bool BaristaTask::encodeMessage(zmq::message_t& message, const SettingsContainer // Use a std::string as QString uses a 16bit representation of chars std::string messageString = messageData.toStdString(); message.rebuild(messageString.size()); - memcpy(message.data(), messageString.data(), messageString.size()); + std::memcpy(message.data(), messageString.data(), messageString.size()); } catch (SerializationException& e) { addMessageLog(QString{"! Unable to encode a message (%1)"}.arg(GetExceptionMessage(e.what()))); return false; @@ -166,26 +156,13 @@ bool BaristaTask::decodeMessage(const zmq::message_t& message, SettingsContainer return true; } -void BaristaTask::sendMessage(const SettingsContainer& settings, std::unique_ptr<zmq::socket_t>& socket, bool receiveAck, bool blocking) +void BaristaTask::sendMessage(const SettingsContainer& settings, std::unique_ptr<BaristaSocket>& socket, bool receiveAck, bool blocking, bool sendMore) { // Just encode the message and send it over the socket zmq::message_t message; if (encodeMessage(message, settings)) - { - try { - socket->send(message, blocking ? ZMQ_DONTWAIT : 0); - } catch (std::exception& e) { - // Ignore ZMQ errors here - } - - if (receiveAck) - { - // Receive a corresponding ACK message from Barista; we are not really interested in it, though, so we just ignore it - zmq::message_t ackMessage; - socket->recv(&ackMessage); - } - } + socket->sendMessage(message, receiveAck, blocking, sendMore); } bool BaristaTask::checkMessageStatus(BaristaMessage& message) const @@ -230,7 +207,7 @@ void BaristaTask::update() } } -void BaristaTask::pollMessage(std::unique_ptr<zmq::socket_t>& socket, std::function<void(BaristaTask*, const SettingsContainer&)> callback) +void BaristaTask::pollMessage(std::unique_ptr<BaristaSocket>& socket, std::function<void(BaristaTask*, const SettingsContainer&)> callback) { // Poll the given socket for a message zmq::pollitem_t pollItems[] = {{*socket, 0, ZMQ_POLLIN, 0}}; @@ -240,7 +217,7 @@ void BaristaTask::pollMessage(std::unique_ptr<zmq::socket_t>& socket, std::funct { // Some data is waiting on the port, so receive it zmq::message_t message; - socket->recv(&message); + socket->recvMessage(message); if (callback && message.size() > 0) { @@ -253,6 +230,9 @@ void BaristaTask::pollMessage(std::unique_ptr<zmq::socket_t>& socket, std::funct } catch (TaskException& e) { // Show errors but ignore them otherwise addMessageLog(QString{"! %1"}.arg(GetExceptionMessage(e.what()))); + } catch (std::exception& e) { + // Show errors but ignore them otherwise + addMessageLog(QString{"! %1"}.arg(e.what())); } } } diff --git a/Grinder/task/tasks/BaristaTask.h b/Grinder/task/tasks/BaristaTask.h index d8978a544a7013bb1e9f0b11e861ef7914bf09b8..27f8d5791e1d8e0a296c8a41d4e168cb2d8b70a7 100644 --- a/Grinder/task/tasks/BaristaTask.h +++ b/Grinder/task/tasks/BaristaTask.h @@ -8,6 +8,7 @@ #include <zmq.hpp> +#include "barista/BaristaSocket.h" #include "task/Task.h" namespace grndr @@ -50,7 +51,7 @@ namespace grndr protected: bool encodeMessage(zmq::message_t& message, const SettingsContainer& settings); bool decodeMessage(const zmq::message_t& message, SettingsContainer& settings); - void sendMessage(const SettingsContainer& settings, std::unique_ptr<zmq::socket_t>& socket, bool receiveAck = true, bool blocking = true); + void sendMessage(const SettingsContainer& settings, std::unique_ptr<BaristaSocket>& socket, bool receiveAck = true, bool blocking = true, bool sendMore = false); bool checkMessageStatus(BaristaMessage& message) const; @@ -65,7 +66,7 @@ namespace grndr virtual void update() override; private: - void pollMessage(std::unique_ptr<zmq::socket_t>& socket, std::function<void(BaristaTask*, const SettingsContainer&)> callback); + void pollMessage(std::unique_ptr<BaristaSocket>& socket, std::function<void(BaristaTask*, const SettingsContainer&)> callback); private: void handleAwaitingConnection(const SettingsContainer& messageData); @@ -105,9 +106,9 @@ namespace grndr protected: std::unique_ptr<zmq::context_t> _context; - std::unique_ptr<zmq::socket_t> _replySocket; - std::unique_ptr<zmq::socket_t> _subscriberSocket; - std::unique_ptr<zmq::socket_t> _requestSocket; + std::unique_ptr<BaristaSocket> _replySocket; + std::unique_ptr<BaristaSocket> _subscriberSocket; + std::unique_ptr<BaristaSocket> _requestSocket; }; } diff --git a/Grinder/task/tasks/BaristaTrainingTask.cpp b/Grinder/task/tasks/BaristaTrainingTask.cpp index cd0f1ec8b7c1da090111e13ee49dd9dd01daba0b..49d1f59a451bd9a2628f8a2f7e1c783535ef8587 100644 --- a/Grinder/task/tasks/BaristaTrainingTask.cpp +++ b/Grinder/task/tasks/BaristaTrainingTask.cpp @@ -5,8 +5,8 @@ #include "Grinder.h" #include "BaristaTrainingTask.h" -#include "BaristaMessage.h" -#include "BaristaProtocol.h" +#include "barista/BaristaMessage.h" +#include "barista/BaristaProtocol.h" #include "ui/task/tasks/BaristaTrainingTaskWidget.h" const TaskType BaristaTrainingTask::type_value = TaskType::BaristaTraining; @@ -50,7 +50,7 @@ void BaristaTrainingTask::pause(bool setPause) void BaristaTrainingTask::baristaReady() { // Start the training - changeTaskState(WorkerTaskState::StartTraining, "Starting training..."); + changeTaskState(TrainingTaskState::StartTraining, QString{"Starting training on '%1'..."}.arg(_solverPath)); sendStartTrainingMessage(); } @@ -69,11 +69,11 @@ bool BaristaTrainingTask::handleSubscriberMessage(const SettingsContainer& messa switch (_taskState) { - case WorkerTaskState::StartTraining: - handleStartingTraining(messageData); + case TrainingTaskState::StartTraining: + handleStartTraining(messageData); return true; - case WorkerTaskState::Training: + case TrainingTaskState::Training: handleTraining(messageData); return true; } @@ -81,7 +81,7 @@ bool BaristaTrainingTask::handleSubscriberMessage(const SettingsContainer& messa return false; } -void BaristaTrainingTask::handleStartingTraining(const SettingsContainer& messageData) +void BaristaTrainingTask::handleStartTraining(const SettingsContainer& messageData) { BaristaMessage message{messageData}; @@ -93,7 +93,7 @@ void BaristaTrainingTask::handleStartingTraining(const SettingsContainer& messag addMessageLog("", false); // Just let the training run... - changeTaskState(WorkerTaskState::Training); + changeTaskState(TrainingTaskState::Training); } else reportBaristaError("\tFailed to start the training", &message); @@ -139,13 +139,11 @@ void BaristaTrainingTask::sendStartTrainingMessage() data("dir") = _sessionPath; data("solver") = _solverPath; message.payload() << std::move(data); - sendMessage(message, _requestSocket); } void BaristaTrainingTask::sendPauseTrainingMessage(bool setPause) { BaristaMessage message{setPause ? BARISTA_COMMAND_PAUSETRAINING : BARISTA_COMMAND_RESUMETRAINING}; - sendMessage(message, _requestSocket); } diff --git a/Grinder/task/tasks/BaristaTrainingTask.h b/Grinder/task/tasks/BaristaTrainingTask.h index 72ffe74c813334337f933500c46388436115d601..964794c550f1d91dfc55f75f9e4b8c65bc760b8a 100644 --- a/Grinder/task/tasks/BaristaTrainingTask.h +++ b/Grinder/task/tasks/BaristaTrainingTask.h @@ -45,21 +45,21 @@ namespace grndr virtual bool handleReplyMessage(const SettingsContainer& messageData) override; virtual bool handleSubscriberMessage(const SettingsContainer& messageData) override; - protected: - void handleStartingTraining(const SettingsContainer& messageData); + private: + void handleStartTraining(const SettingsContainer& messageData); void handleTraining(const SettingsContainer& messageData); void sendStartTrainingMessage(); void sendPauseTrainingMessage(bool setPause); - protected: - enum WorkerTaskState + private: + enum TrainingTaskState { StartTraining = TypeSpecificBase, Training, }; - protected: + private: QString _sessionPath{""}; QString _solverPath{""}; }; diff --git a/Grinder/ui/UIUtils.h b/Grinder/ui/UIUtils.h index 40946daf418281a9ba1c41e68108c92a731706b1..5e9526bee81b36987283d23f93784fda162aa7b1 100644 --- a/Grinder/ui/UIUtils.h +++ b/Grinder/ui/UIUtils.h @@ -22,7 +22,7 @@ namespace grndr static QString askFileName(bool saveFileName, QString dlgName, QWidget* parent = nullptr, QString caption = "", QString filter = "", QString* selectedFilter = nullptr, QFileDialog::Options options = QFileDialog::Options{}); static QStringList askFileNames(QString dlgName, QWidget* parent = nullptr, QString caption = "", QString filter = "", QString* selectedFilter = nullptr, QFileDialog::Options options = QFileDialog::Options{}); - static void removeChildrenFromLayout(QLayout* layout); + static void removeChildrenFromLayout(QLayout* layout); private: UIUtils() { } diff --git a/Grinder/ui/dlg/HDF5ExportDialog.cpp b/Grinder/ui/dlg/HDF5ExportDialog.cpp index b917ca6b46db486449bd2805f4f56e583b5e3c6e..72b42a50cc342a2a598a2481480dc9c7f19f4ffd 100644 --- a/Grinder/ui/dlg/HDF5ExportDialog.cpp +++ b/Grinder/ui/dlg/HDF5ExportDialog.cpp @@ -9,9 +9,6 @@ #include "core/GrinderApplication.h" #include "project/Project.h" -Q_DECLARE_METATYPE(Label*); -Q_DECLARE_METATYPE(Block*); - HDF5ExportDialog::HDF5ExportDialog(const Project* project, QWidget *parent) : QDialog(parent, Qt::Dialog|Qt::WindowTitleHint|Qt::WindowCloseButtonHint), ui{new Ui::HDF5ExportDialog} { @@ -28,17 +25,17 @@ HDF5ExportDialog::~HDF5ExportDialog() Label* HDF5ExportDialog::getLabel() const { - return ui->lstLabel->currentData().value<Label*>(); + return ui->lstLabel->getSelectedLabel(); } Block* HDF5ExportDialog::getOutputBlock() const { - return ui->lstOutputBlock->currentData().value<Block*>(); + return ui->lstOutputBlock->getSelectedOutputBlock(); } std::vector<const ImageReference*> HDF5ExportDialog::getImageReferences() const { - return _imageReferencesListWidget->getCheckedObjects<const ImageReference>(); + return ui->lstImageReferences->getSelectedImageReferences(); } bool HDF5ExportDialog::exportImageTags() const @@ -50,76 +47,34 @@ void HDF5ExportDialog::setupUi() { ui->setupUi(this); - _imageReferencesListWidget = new CheckListWidget<ImageReference, ImageReferencesListItem>{}; - ui->imagesBoxLayout->addWidget(_imageReferencesListWidget); - // The OK button should only be enabled if at least one image reference is checked - connect(_imageReferencesListWidget, &QListWidget::itemChanged, this, &HDF5ExportDialog::imageReferencesListItemChanged); + connect(ui->lstImageReferences, &QListWidget::itemChanged, this, &HDF5ExportDialog::imageReferencesListItemChanged); updateUi(); } void HDF5ExportDialog::updateUi() { - ui->buttonBox->button(QDialogButtonBox::Ok)->setEnabled(_imageReferencesListWidget->getCheckedObjects().size() > 0); + ui->buttonBox->button(QDialogButtonBox::Ok)->setEnabled(ui->lstImageReferences->getCheckedObjects().size() > 0); } void HDF5ExportDialog::fillLabels(const LabelVector& labels) { - int index = 0; - - // Add all labels - for (const auto& label : labels) - { - ui->lstLabel->addItem(label->getName(), QVariant::fromValue(label.get())); - - if (label.get() == grinder()->projectController().activeLabel()) - index = ui->lstLabel->count() - 1; - } - - if (ui->lstLabel->count() == 0) - { - ui->lstLabel->addItem("No labels to display", QVariant::fromValue<Label*>(nullptr)); - ui->lstLabel->setEnabled(false); - } - - ui->lstLabel->setCurrentIndex(index); + ui->lstLabel->populate(labels); } void HDF5ExportDialog::fillOutputBlocks(const Label* label) { - ui->lstOutputBlock->clear(); - - // The first item always represents the raw input images - ui->lstOutputBlock->addItem("None (export unaltered images)", QVariant::fromValue<Block*>(nullptr)); - - auto boldFont = font(); - boldFont.setBold(true); - ui->lstOutputBlock->setItemData(0, boldFont, Qt::FontRole); - - ui->lstOutputBlock->insertSeparator(1); - - if (label) - { - // Add all output blocks - for (const auto& block : label->pipeline()->blocks()) - { - if (block->getType() == BlockType::Output) - ui->lstOutputBlock->addItem(block->getName(), QVariant::fromValue<Block*>(block.get())); - } - } - - ui->lstOutputBlock->setCurrentIndex(0); + ui->lstOutputBlock->populate(label, "None (export unaltered images)"); } void HDF5ExportDialog::fillImageReferences(const ImageReferenceVector& imageRefs) { - _imageReferencesListWidget->populateList(imageRefs); - _imageReferencesListWidget->setCheckedObjects(imageRefs.toVector()); + ui->lstImageReferences->populate(imageRefs); } -void grndr::HDF5ExportDialog::on_lstLabel_currentIndexChanged(int index) +void HDF5ExportDialog::on_lstLabel_currentIndexChanged(int index) { Q_UNUSED(index); - fillOutputBlocks(ui->lstLabel->currentData().value<Label*>()); + fillOutputBlocks(ui->lstLabel->getSelectedLabel()); } diff --git a/Grinder/ui/dlg/HDF5ExportDialog.h b/Grinder/ui/dlg/HDF5ExportDialog.h index a17952239ac13319a9bb90e42b70eb909ed1b2b7..b74d4d82bb2dfc37b03420fb07f948a19e71e681 100644 --- a/Grinder/ui/dlg/HDF5ExportDialog.h +++ b/Grinder/ui/dlg/HDF5ExportDialog.h @@ -9,7 +9,7 @@ #include <QDialog> #include "project/ImageReference.h" -#include "ui/widgets/CheckListWidget.h" +#include "ui/widgets/ImageReferencesCheckListWidget.h" #include "ui/mainwnd/ImageReferencesListItem.h" namespace Ui @@ -20,8 +20,8 @@ namespace Ui namespace grndr { class Project; - class LabelVector; class Label; + class LabelVector; class ImageReferenceVector; class HDF5ExportDialog : public QDialog @@ -54,9 +54,6 @@ namespace grndr void fillLabels(const LabelVector& labels); void fillOutputBlocks(const Label* label); void fillImageReferences(const ImageReferenceVector& imageRefs); - - private: - CheckListWidget<ImageReference, ImageReferencesListItem>* _imageReferencesListWidget{nullptr}; }; } diff --git a/Grinder/ui/dlg/HDF5ExportDialog.ui b/Grinder/ui/dlg/HDF5ExportDialog.ui index 8c97a3586e967f23f56d09dcea04673332102ef6..da296bcd0bf3620fb0e50f7420b503a008d67a41 100644 --- a/Grinder/ui/dlg/HDF5ExportDialog.ui +++ b/Grinder/ui/dlg/HDF5ExportDialog.ui @@ -40,7 +40,7 @@ </widget> </item> <item row="1" column="1"> - <widget class="QComboBox" name="lstOutputBlock"> + <widget class="OutputBlocksComboBox" name="lstOutputBlock"> <property name="sizePolicy"> <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> <horstretch>0</horstretch> @@ -72,7 +72,7 @@ </widget> </item> <item row="0" column="1"> - <widget class="QComboBox" name="lstLabel"> + <widget class="LabelsComboBox" name="lstLabel"> <property name="sizePolicy"> <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> <horstretch>0</horstretch> @@ -103,7 +103,7 @@ </property> <layout class="QGridLayout" name="gridLayout_2"> <item row="0" column="0"> - <layout class="QVBoxLayout" name="imagesBoxLayout"/> + <widget class="ImageReferencesCheckListWidget" name="lstImageReferences"/> </item> </layout> </widget> @@ -155,6 +155,23 @@ </item> </layout> </widget> + <customwidgets> + <customwidget> + <class>LabelsComboBox</class> + <extends>QComboBox</extends> + <header>ui/widgets/LabelsComboBox.h</header> + </customwidget> + <customwidget> + <class>OutputBlocksComboBox</class> + <extends>QComboBox</extends> + <header>ui/widgets/OutputBlocksComboBox.h</header> + </customwidget> + <customwidget> + <class>ImageReferencesCheckListWidget</class> + <extends>QListWidget</extends> + <header>ui/widgets/ImageReferencesCheckListWidget.h</header> + </customwidget> + </customwidgets> <tabstops> <tabstop>lstLabel</tabstop> <tabstop>lstOutputBlock</tabstop> diff --git a/Grinder/ui/dlg/TextViewerDialog.ui b/Grinder/ui/dlg/TextViewerDialog.ui index 62c6ced4eeea655fd699439036669ef14283b53f..aff1912334825ce47ea41021a74923cda858ff16 100644 --- a/Grinder/ui/dlg/TextViewerDialog.ui +++ b/Grinder/ui/dlg/TextViewerDialog.ui @@ -44,6 +44,9 @@ <pointsize>10</pointsize> </font> </property> + <property name="lineWrapMode"> + <enum>QPlainTextEdit::NoWrap</enum> + </property> <property name="readOnly"> <bool>true</bool> </property> diff --git a/Grinder/ui/task/tasks/BaristaInferenceTaskWidget.cpp b/Grinder/ui/task/tasks/BaristaInferenceTaskWidget.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d84ae2b5e85580b3648483d20b52bbe1df173de --- /dev/null +++ b/Grinder/ui/task/tasks/BaristaInferenceTaskWidget.cpp @@ -0,0 +1,108 @@ +/****************************************************************************** + * File: BaristaInferenceTaskWidget.cpp + * Date: 07.12.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "BaristaInferenceTaskWidget.h" +#include "ui_BaristaInferenceTaskWidget.h" +#include "core/GrinderApplication.h" +#include "task/tasks/BaristaInferenceTask.h" + +BaristaInferenceTaskWidget::BaristaInferenceTaskWidget(BaristaInferenceTask* task, bool newTask, QWidget* parent) : ConfigureTaskWidget(task, newTask, parent), + ui{new Ui::BaristaInferenceTaskWidget} +{ + setupUi(); + + fillLabels(grinder()->project().labels()); + fillImageReferences(grinder()->project().imageReferences()); +} + +BaristaInferenceTaskWidget::~BaristaInferenceTaskWidget() +{ + delete ui; +} + +void BaristaInferenceTaskWidget::verifySettings() +{ + if (ui->txtLibraryPath->text().isEmpty()) + showError("Please enter a library path.", ui->txtLibraryPath); + + if (ui->txtNetworkPath->text().isEmpty()) + showError("Please enter a network path.", ui->txtNetworkPath); + + if (ui->txtModelPath->text().isEmpty()) + showError("Please enter a model path.", ui->txtModelPath); + + if (getImageReferences().empty()) + showError("Please select at least one image.", ui->lstImageReferences); +} + +void BaristaInferenceTaskWidget::applySettings(bool save) +{ + if (save) + { + _task->setBaristaPort(ui->txtWorkerPort->value()); + + _task->setLibraryPath(ui->txtLibraryPath->text()); + _task->setNetworkPath(ui->txtNetworkPath->text()); + _task->setModelPath(ui->txtModelPath->text()); + + _task->setLabel(getLabel()); + _task->setOutputBlock(getOutputBlock()); + _task->setImageReferences(getImageReferences()); + } + else + { + ui->txtWorkerPort->setValue(_task->getBaristaPort()); + + ui->txtLibraryPath->setText(_task->getLibraryPath()); + ui->txtNetworkPath->setText(_task->getNetworkPath()); + ui->txtModelPath->setText(_task->getModelPath()); + + ui->lstLabel->selectLabel(_task->getLabel()); + ui->lstOutputBlock->selectOutputBlock(_task->getOutputBlock()); + ui->lstImageReferences->selectImageReferences(_task->getImageReferences()); + } +} + +void BaristaInferenceTaskWidget::on_lstLabel_currentIndexChanged(int index) +{ + Q_UNUSED(index); + fillOutputBlocks(ui->lstLabel->getSelectedLabel()); +} + +void BaristaInferenceTaskWidget::setupUi() +{ + ui->setupUi(this); +} + +void BaristaInferenceTaskWidget::fillLabels(const LabelVector& labels) +{ + ui->lstLabel->populate(labels); +} + +void BaristaInferenceTaskWidget::fillOutputBlocks(const Label* label) +{ + ui->lstOutputBlock->populate(label, "None (use unaltered images)"); +} + +void BaristaInferenceTaskWidget::fillImageReferences(const ImageReferenceVector& imageRefs) +{ + ui->lstImageReferences->populate(imageRefs); +} + +Label* BaristaInferenceTaskWidget::getLabel() const +{ + return ui->lstLabel->getSelectedLabel(); +} + +Block* BaristaInferenceTaskWidget::getOutputBlock() const +{ + return ui->lstOutputBlock->getSelectedOutputBlock(); +} + +std::vector<const ImageReference*> BaristaInferenceTaskWidget::getImageReferences() const +{ + return ui->lstImageReferences->getSelectedImageReferences(); +} diff --git a/Grinder/ui/task/tasks/BaristaInferenceTaskWidget.h b/Grinder/ui/task/tasks/BaristaInferenceTaskWidget.h new file mode 100644 index 0000000000000000000000000000000000000000..0277e8e1b35090f06d6aa051d0dac9de5469c7d2 --- /dev/null +++ b/Grinder/ui/task/tasks/BaristaInferenceTaskWidget.h @@ -0,0 +1,55 @@ +/****************************************************************************** + * File: BaristaInferenceTaskWidget.h + * Date: 07.12.2018 + *****************************************************************************/ + +#ifndef BARISTAINFERENCETASKWIDGET_H +#define BARISTAINFERENCETASKWIDGET_H + +#include "ui/task/ConfigureTaskWidget.h" + +namespace Ui +{ + class BaristaInferenceTaskWidget; +} + +namespace grndr +{ + class BaristaInferenceTask; + class Label; + class LabelVector; + class ImageReference; + class ImageReferenceVector; + class Block; + + class BaristaInferenceTaskWidget : public ConfigureTaskWidget<BaristaInferenceTask> + { + Q_OBJECT + + public: + explicit BaristaInferenceTaskWidget(BaristaInferenceTask* task, bool newTask, QWidget *parent = nullptr); + virtual ~BaristaInferenceTaskWidget(); + + public: + virtual void verifySettings() override; + virtual void applySettings(bool save) override; + + private slots: + void on_lstLabel_currentIndexChanged(int index); + + private: + Ui::BaristaInferenceTaskWidget *ui; + void setupUi(); + + private: + void fillLabels(const LabelVector& labels); + void fillOutputBlocks(const Label* label); + void fillImageReferences(const ImageReferenceVector& imageRefs); + + Label* getLabel() const; + Block* getOutputBlock() const; + std::vector<const ImageReference*> getImageReferences() const; + }; +} + +#endif diff --git a/Grinder/ui/task/tasks/BaristaInferenceTaskWidget.ui b/Grinder/ui/task/tasks/BaristaInferenceTaskWidget.ui new file mode 100644 index 0000000000000000000000000000000000000000..f3d8a08cca5ee86cf3400c230896cc045f8570ae --- /dev/null +++ b/Grinder/ui/task/tasks/BaristaInferenceTaskWidget.ui @@ -0,0 +1,232 @@ +<?xml version="1.0" encoding="UTF-8"?> +<ui version="4.0"> + <class>BaristaInferenceTaskWidget</class> + <widget class="QWidget" name="BaristaInferenceTaskWidget"> + <property name="geometry"> + <rect> + <x>0</x> + <y>0</y> + <width>344</width> + <height>417</height> + </rect> + </property> + <property name="windowTitle"> + <string>Form</string> + </property> + <layout class="QGridLayout" name="gridLayout"> + <property name="leftMargin"> + <number>0</number> + </property> + <property name="topMargin"> + <number>0</number> + </property> + <property name="rightMargin"> + <number>0</number> + </property> + <property name="bottomMargin"> + <number>0</number> + </property> + <item row="0" column="1"> + <widget class="QSpinBox" name="txtWorkerPort"> + <property name="minimum"> + <number>1024</number> + </property> + <property name="maximum"> + <number>65536</number> + </property> + </widget> + </item> + <item row="0" column="0"> + <widget class="QLabel" name="label"> + <property name="text"> + <string>&Worker port:</string> + </property> + <property name="buddy"> + <cstring>txtWorkerPort</cstring> + </property> + </widget> + </item> + <item row="0" column="2"> + <spacer name="horizontalSpacer"> + <property name="orientation"> + <enum>Qt::Horizontal</enum> + </property> + <property name="sizeHint" stdset="0"> + <size> + <width>265</width> + <height>20</height> + </size> + </property> + </spacer> + </item> + <item row="1" column="0"> + <spacer name="verticalSpacer_2"> + <property name="orientation"> + <enum>Qt::Vertical</enum> + </property> + <property name="sizeType"> + <enum>QSizePolicy::Fixed</enum> + </property> + <property name="sizeHint" stdset="0"> + <size> + <width>20</width> + <height>6</height> + </size> + </property> + </spacer> + </item> + <item row="4" column="0"> + <widget class="QLabel" name="label_4"> + <property name="text"> + <string>&Model path:</string> + </property> + <property name="buddy"> + <cstring>txtModelPath</cstring> + </property> + </widget> + </item> + <item row="3" column="0"> + <widget class="QLabel" name="label_3"> + <property name="text"> + <string>&Network path:</string> + </property> + <property name="buddy"> + <cstring>txtNetworkPath</cstring> + </property> + </widget> + </item> + <item row="3" column="1" colspan="2"> + <widget class="QLineEdit" name="txtNetworkPath"/> + </item> + <item row="4" column="1" colspan="2"> + <widget class="QLineEdit" name="txtModelPath"/> + </item> + <item row="2" column="1" colspan="2"> + <widget class="QLineEdit" name="txtLibraryPath"/> + </item> + <item row="5" column="1"> + <spacer name="verticalSpacer_3"> + <property name="orientation"> + <enum>Qt::Vertical</enum> + </property> + <property name="sizeType"> + <enum>QSizePolicy::Fixed</enum> + </property> + <property name="sizeHint" stdset="0"> + <size> + <width>20</width> + <height>6</height> + </size> + </property> + </spacer> + </item> + <item row="6" column="0"> + <widget class="QLabel" name="label_6"> + <property name="text"> + <string>Label:</string> + </property> + <property name="buddy"> + <cstring>lstLabel</cstring> + </property> + </widget> + </item> + <item row="6" column="1" colspan="2"> + <widget class="LabelsComboBox" name="lstLabel"> + <property name="sizePolicy"> + <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="minimumSize"> + <size> + <width>150</width> + <height>0</height> + </size> + </property> + </widget> + </item> + <item row="2" column="0"> + <widget class="QLabel" name="label_2"> + <property name="text"> + <string>Library &path:</string> + </property> + <property name="buddy"> + <cstring>txtLibraryPath</cstring> + </property> + </widget> + </item> + <item row="10" column="0" colspan="3"> + <widget class="ImageReferencesCheckListWidget" name="lstImageReferences"/> + </item> + <item row="7" column="1" colspan="2"> + <widget class="OutputBlocksComboBox" name="lstOutputBlock"> + <property name="sizePolicy"> + <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="minimumSize"> + <size> + <width>150</width> + <height>0</height> + </size> + </property> + </widget> + </item> + <item row="9" column="0"> + <widget class="QLabel" name="label_7"> + <property name="text"> + <string>Images:</string> + </property> + </widget> + </item> + <item row="7" column="0"> + <widget class="QLabel" name="label_5"> + <property name="text"> + <string>Output block:</string> + </property> + <property name="buddy"> + <cstring>lstOutputBlock</cstring> + </property> + </widget> + </item> + <item row="8" column="1"> + <spacer name="verticalSpacer_5"> + <property name="orientation"> + <enum>Qt::Vertical</enum> + </property> + <property name="sizeType"> + <enum>QSizePolicy::Fixed</enum> + </property> + <property name="sizeHint" stdset="0"> + <size> + <width>20</width> + <height>6</height> + </size> + </property> + </spacer> + </item> + </layout> + </widget> + <customwidgets> + <customwidget> + <class>LabelsComboBox</class> + <extends>QComboBox</extends> + <header>ui/widgets/LabelsComboBox.h</header> + </customwidget> + <customwidget> + <class>OutputBlocksComboBox</class> + <extends>QComboBox</extends> + <header>ui/widgets/OutputBlocksComboBox.h</header> + </customwidget> + <customwidget> + <class>ImageReferencesCheckListWidget</class> + <extends>QListWidget</extends> + <header>ui/widgets/ImageReferencesCheckListWidget.h</header> + </customwidget> + </customwidgets> + <resources/> + <connections/> +</ui> diff --git a/Grinder/ui/task/tasks/BaristaTrainingTaskWidget.cpp b/Grinder/ui/task/tasks/BaristaTrainingTaskWidget.cpp index 5adc76b66bcc6bc3789f4e2cf5c33dde5942a26e..91c7e4f668ba0cb74b9e8ced23a5cc4e6ee67c8f 100644 --- a/Grinder/ui/task/tasks/BaristaTrainingTaskWidget.cpp +++ b/Grinder/ui/task/tasks/BaristaTrainingTaskWidget.cpp @@ -31,11 +31,6 @@ void BaristaTrainingTaskWidget::verifySettings() showError("Please enter a solver path.", ui->txtSolverPath); } -void BaristaTrainingTaskWidget::setupUi() -{ - ui->setupUi(this); -} - void BaristaTrainingTaskWidget::applySettings(bool save) { if (save) @@ -55,3 +50,8 @@ void BaristaTrainingTaskWidget::applySettings(bool save) ui->txtSolverPath->setText(_task->getSolverPath()); } } + +void BaristaTrainingTaskWidget::setupUi() +{ + ui->setupUi(this); +} diff --git a/Grinder/ui/task/tasks/BaristaTrainingTaskWidget.ui b/Grinder/ui/task/tasks/BaristaTrainingTaskWidget.ui index 6b0700f212c84c1314fd53289806af8f8b3c34bc..774c899750ea901b9fa4b11435557fdc6ce5a824 100644 --- a/Grinder/ui/task/tasks/BaristaTrainingTaskWidget.ui +++ b/Grinder/ui/task/tasks/BaristaTrainingTaskWidget.ui @@ -26,6 +26,26 @@ <property name="bottomMargin"> <number>0</number> </property> + <item row="0" column="0"> + <widget class="QLabel" name="label"> + <property name="text"> + <string>&Worker port:</string> + </property> + <property name="buddy"> + <cstring>txtWorkerPort</cstring> + </property> + </widget> + </item> + <item row="4" column="0"> + <widget class="QLabel" name="label_4"> + <property name="text"> + <string>&Solver path:</string> + </property> + <property name="buddy"> + <cstring>txtSolverPath</cstring> + </property> + </widget> + </item> <item row="2" column="0"> <widget class="QLabel" name="label_2"> <property name="text"> @@ -36,6 +56,16 @@ </property> </widget> </item> + <item row="3" column="0"> + <widget class="QLabel" name="label_3"> + <property name="text"> + <string>Sessio&n path:</string> + </property> + <property name="buddy"> + <cstring>txtSessionPath</cstring> + </property> + </widget> + </item> <item row="2" column="1" colspan="2"> <widget class="QLineEdit" name="txtLibraryPath"/> </item> @@ -49,45 +79,41 @@ </property> </widget> </item> - <item row="0" column="2"> - <spacer name="horizontalSpacer"> + <item row="1" column="1"> + <spacer name="verticalSpacer_4"> <property name="orientation"> - <enum>Qt::Horizontal</enum> + <enum>Qt::Vertical</enum> + </property> + <property name="sizeType"> + <enum>QSizePolicy::Fixed</enum> </property> <property name="sizeHint" stdset="0"> <size> - <width>40</width> - <height>20</height> + <width>20</width> + <height>6</height> </size> </property> </spacer> </item> - <item row="0" column="0"> - <widget class="QLabel" name="label"> - <property name="text"> - <string>&Worker port:</string> - </property> - <property name="buddy"> - <cstring>txtWorkerPort</cstring> - </property> - </widget> + <item row="4" column="1" colspan="2"> + <widget class="QLineEdit" name="txtSolverPath"/> </item> - <item row="1" column="0"> - <spacer name="verticalSpacer_2"> + <item row="0" column="2"> + <spacer name="horizontalSpacer"> <property name="orientation"> - <enum>Qt::Vertical</enum> - </property> - <property name="sizeType"> - <enum>QSizePolicy::Fixed</enum> + <enum>Qt::Horizontal</enum> </property> <property name="sizeHint" stdset="0"> <size> - <width>20</width> - <height>6</height> + <width>40</width> + <height>20</height> </size> </property> </spacer> </item> + <item row="3" column="1" colspan="2"> + <widget class="QLineEdit" name="txtSessionPath"/> + </item> <item row="5" column="0"> <spacer name="verticalSpacer"> <property name="orientation"> @@ -101,32 +127,6 @@ </property> </spacer> </item> - <item row="3" column="1" colspan="2"> - <widget class="QLineEdit" name="txtSessionPath"/> - </item> - <item row="3" column="0"> - <widget class="QLabel" name="label_3"> - <property name="text"> - <string>Sessio&n path:</string> - </property> - <property name="buddy"> - <cstring>txtSessionPath</cstring> - </property> - </widget> - </item> - <item row="4" column="0"> - <widget class="QLabel" name="label_4"> - <property name="text"> - <string>&Solver path:</string> - </property> - <property name="buddy"> - <cstring>txtSolverPath</cstring> - </property> - </widget> - </item> - <item row="4" column="1" colspan="2"> - <widget class="QLineEdit" name="txtSolverPath"/> - </item> </layout> </widget> <tabstops> diff --git a/Grinder/ui/widgets/ImageReferencesCheckListWidget.cpp b/Grinder/ui/widgets/ImageReferencesCheckListWidget.cpp new file mode 100644 index 0000000000000000000000000000000000000000..97158caeb66016de3861f3a82957bdecc2a145e8 --- /dev/null +++ b/Grinder/ui/widgets/ImageReferencesCheckListWidget.cpp @@ -0,0 +1,23 @@ +/****************************************************************************** + * File: ImageReferencesCheckListWidget.cpp + * Date: 13.12.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "ImageReferencesCheckListWidget.h" + +void ImageReferencesCheckListWidget::populate(const ImageReferenceVector& imageRefs) +{ + populateList(imageRefs); + setCheckedObjects(imageRefs.toVector()); +} + +std::vector<const ImageReference*> ImageReferencesCheckListWidget::getSelectedImageReferences() const +{ + return getCheckedObjects<const ImageReference>(); +} + +void ImageReferencesCheckListWidget::selectImageReferences(const std::vector<const ImageReference*>& imageRefs) +{ + setCheckedObjects(imageRefs); +} diff --git a/Grinder/ui/widgets/ImageReferencesCheckListWidget.h b/Grinder/ui/widgets/ImageReferencesCheckListWidget.h new file mode 100644 index 0000000000000000000000000000000000000000..ace23b4a361224b492e71cbc651571b0e6b44987 --- /dev/null +++ b/Grinder/ui/widgets/ImageReferencesCheckListWidget.h @@ -0,0 +1,30 @@ +/****************************************************************************** + * File: ImageReferencesCheckListWidget.h + * Date: 13.12.2018 + *****************************************************************************/ + +#ifndef IMAGEREFERENCESCHECKLISTWIDGET_H +#define IMAGEREFERENCESCHECKLISTWIDGET_H + +#include "ui/widgets/CheckListWidget.h" +#include "ui/mainwnd/ImageReferencesListItem.h" +#include "project/ImageReferenceVector.h" + +namespace grndr +{ + class ImageReferencesCheckListWidget : public CheckListWidget<ImageReference, ImageReferencesListItem> + { + Q_OBJECT + + public: + using CheckListWidget::CheckListWidget; + + public: + void populate(const ImageReferenceVector& imageRefs); + + std::vector<const ImageReference*> getSelectedImageReferences() const; + void selectImageReferences(const std::vector<const ImageReference*>& imageRefs); + }; +} + +#endif diff --git a/Grinder/ui/widgets/LabelsComboBox.cpp b/Grinder/ui/widgets/LabelsComboBox.cpp new file mode 100644 index 0000000000000000000000000000000000000000..164c2075249137256c1b0a2213b664bb3c478b8d --- /dev/null +++ b/Grinder/ui/widgets/LabelsComboBox.cpp @@ -0,0 +1,50 @@ +/****************************************************************************** + * File: LabelsComboBox.cpp + * Date: 13.12.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "LabelsComboBox.h" +#include "project/LabelVector.h" +#include "core/GrinderApplication.h" + +Q_DECLARE_METATYPE(Label*); + +void LabelsComboBox::populate(const LabelVector& labels) +{ + int index = 0; + + // Add all labels + for (const auto& label : labels) + { + addItem(label->getName(), QVariant::fromValue(label.get())); + + if (label.get() == grinder()->projectController().activeLabel()) + index = count() - 1; + } + + if (count() == 0) + { + addItem("No labels to display", QVariant::fromValue<Label*>(nullptr)); + setEnabled(false); + } + + setCurrentIndex(index); +} + +Label* LabelsComboBox::getSelectedLabel() const +{ + return currentData().value<Label*>(); +} + +void LabelsComboBox::selectLabel(const Label* label) +{ + for (int i = 0; i < count(); ++i) + { + if (itemData(i).value<Label*>() == label) + { + setCurrentIndex(i); + break; + } + } +} diff --git a/Grinder/ui/widgets/LabelsComboBox.h b/Grinder/ui/widgets/LabelsComboBox.h new file mode 100644 index 0000000000000000000000000000000000000000..011c7f32d4690d6309c48c14be811b8d3a820f08 --- /dev/null +++ b/Grinder/ui/widgets/LabelsComboBox.h @@ -0,0 +1,31 @@ +/****************************************************************************** + * File: LabelsComboBox.h + * Date: 13.12.2018 + *****************************************************************************/ + +#ifndef LABELSCOMBOBOX_H +#define LABELSCOMBOBOX_H + +#include <QComboBox> + +namespace grndr +{ + class Label; + class LabelVector; + + class LabelsComboBox : public QComboBox + { + Q_OBJECT + + public: + using QComboBox::QComboBox; + + public: + void populate(const LabelVector& labels); + + Label* getSelectedLabel() const; + void selectLabel(const Label* label); + }; +} + +#endif diff --git a/Grinder/ui/widgets/OutputBlocksComboBox.cpp b/Grinder/ui/widgets/OutputBlocksComboBox.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eb1af1061d99506a507342e2d9b5f0579bdd0e83 --- /dev/null +++ b/Grinder/ui/widgets/OutputBlocksComboBox.cpp @@ -0,0 +1,53 @@ +/****************************************************************************** + * File: OutputBlocksComboBox.cpp + * Date: 13.12.2018 + *****************************************************************************/ + +#include "Grinder.h" +#include "OutputBlocksComboBox.h" +#include "project/Label.h" + +Q_DECLARE_METATYPE(Block*); + +void OutputBlocksComboBox::populate(const Label* label, QString noneItemText) +{ + clear(); + + // The first item always represents the raw input images + addItem(noneItemText, QVariant::fromValue<Block*>(nullptr)); + + auto boldFont = font(); + boldFont.setBold(true); + setItemData(0, boldFont, Qt::FontRole); + + insertSeparator(1); + + if (label) + { + // Add all output blocks + for (const auto& block : label->pipeline()->blocks()) + { + if (block->getType() == BlockType::Output) + addItem(block->getName(), QVariant::fromValue<Block*>(block.get())); + } + } + + setCurrentIndex(0); +} + +Block* OutputBlocksComboBox::getSelectedOutputBlock() const +{ + return currentData().value<Block*>(); +} + +void OutputBlocksComboBox::selectOutputBlock(const Block* block) +{ + for (int i = 0; i < count(); ++i) + { + if (itemData(i).value<Block*>() == block) + { + setCurrentIndex(i); + break; + } + } +} diff --git a/Grinder/ui/widgets/OutputBlocksComboBox.h b/Grinder/ui/widgets/OutputBlocksComboBox.h new file mode 100644 index 0000000000000000000000000000000000000000..1b5e17162e27608cdc5e23c26c9c3f9dde82d506 --- /dev/null +++ b/Grinder/ui/widgets/OutputBlocksComboBox.h @@ -0,0 +1,31 @@ +/****************************************************************************** + * File: OutputBlocksComboBox.h + * Date: 13.12.2018 + *****************************************************************************/ + +#ifndef OUTPUTBLOCKSCOMBOBOX_H +#define OUTPUTBLOCKSCOMBOBOX_H + +#include <QComboBox> + +namespace grndr +{ + class Label; + class Block; + + class OutputBlocksComboBox : public QComboBox + { + Q_OBJECT + + public: + using QComboBox::QComboBox; + + public: + void populate(const Label* label, QString noneItemText = "None"); + + Block* getSelectedOutputBlock() const; + void selectOutputBlock(const Block* block); + }; +} + +#endif