fixed a bug where barista crashes if the remote host has no caffe-version

fixed a bug where barista crashes if a remote session with an unknown layer is imported
implemented a warning screen to notify the user that a remote session cannot be imported if it has layers unknown to the current caffe-version.
implemented the 'restart barista dialog' if the current default version is deleted
parent 9b40d54e
......@@ -8,6 +8,7 @@ from datetime import datetime
from PyQt5.Qt import QObject
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtWidgets import QMessageBox
from backend.barista.constraints.permanent.project import ensureProjectDataConstraints
from backend.barista.session.session import Session
from backend.barista.session.client_session import ClientSession
......@@ -491,6 +492,23 @@ class Project(QObject, LogCaller):
def loadRemoteSession(self, remote, uid):
sid = self.getNextSessionId()
session = ClientSession(self, remote, uid, sid)
availableTypes = info.CaffeMetaInformation().availableLayerTypes()
unknownLayerTypes = []
for layerID in session.state_dictionary["network"]["layers"]:
type = session.state_dictionary["network"]["layers"][layerID]["parameters"]["type"]
if not type in availableTypes and not type in unknownLayerTypes:
unknownLayerTypes.append(type)
if len(unknownLayerTypes) > 0:
msg = "Cannot load session. The selected session contains layers unknown to the current caffe-version.\n\nUnknown layers:"
for type in unknownLayerTypes:
msg += "\n" + type
msgBox = QMessageBox(QMessageBox.Warning, "Warning", msg)
msgBox.addButton("Ok", QMessageBox.NoRole)
msgBox.exec_()
return None
self.__sessions[sid] = session
self.newSession.emit(sid)
return sid
......
......@@ -372,9 +372,11 @@ class ServerTransaction(Transaction):
def _getDefaultCaffeVersion(self):
msg = self.asyncRead()
defaultVersionName = caffeVersions.getDefaultVersion().getName()
msg["defaultVersionName"] = defaultVersionName
msg["status"] = True
if caffeVersions.getDefaultVersion():
msg["defaultVersionName"] = caffeVersions.getDefaultVersion().getName()
msg["status"] = True
else:
msg["status"] = False
self.send(msg)
def _getFileHash(self):
......
......@@ -15,7 +15,6 @@ from PyQt5.QtWidgets import (
from backend.barista.utils.logger import Log
from gui.caffepath_dialog import CaffepathDialog
from gui.caffe_version_manager.link_to_host_dialog import LinkToHostDialog
import backend.barista.caffe_versions as caffeVersions
from gui.caffe_version_manager.caffe_version_widget import CaffeVersionWidget
from backend.networking.net_util import sendMsgToHost
......@@ -138,7 +137,8 @@ class CaffeVersionManager(QDialog):
msg = {"key": Protocol.GETDEFAULTCAFFEVERSION}
reply = sendMsgToHost(host.host, host.port, msg)
if reply:
remoteCurrent = reply["defaultVersionName"]
if reply["status"]:
remoteCurrent = reply["defaultVersionName"]
msg = {"key": Protocol.GETCAFFEVERSIONS}
reply = sendMsgToHost(host.host, host.port, msg)
......
......@@ -86,17 +86,13 @@ class CaffeVersionWidget(QWidget):
def _onRemoveVersion(self):
"""Remove this version from the caffe_version_manager"""
self.versionManager._onRemoveVersion(self.caffe_version.getName(), self.host)
self.restartWarning()
def _onSetCurrent(self):
"""Sets this version as the current projects/remote hosts caffe version"""
if self.host == None:
self.versionManager.project.changeProjectCaffeVersion(self.caffe_version.getName())
msgBox = QMessageBox(QMessageBox.Warning, "Warning", "Please restart Barista client for changes to apply, otherwise Barista may be unstable!")
msgBox.addButton("Ok", QMessageBox.NoRole)
msgBox.addButton("Restart now", QMessageBox.YesRole)
if msgBox.exec_() == 1:
self.versionManager.actions.restart()
self.restartWarning()
else:
msg = {"key": Protocol.SETCURRENTCAFFEVERSION, "versionname": self.caffe_version.getName()}
sendMsgToHost(self.host.host, self.host.port, msg)
......@@ -107,4 +103,11 @@ class CaffeVersionWidget(QWidget):
if msgBox.exec_() == 1:
msg = {"key": Protocol.RESTART, "pid": self.versionManager.project.getProjectId()}
sendMsgToHost(self.host.host, self.host.port, msg)
self.versionManager.updateList()
\ No newline at end of file
self.versionManager.updateList()
def restartWarning(self):
msgBox = QMessageBox(QMessageBox.Warning, "Warning", "Please restart Barista client for changes to apply, otherwise Barista may be unstable!")
msgBox.addButton("Ok", QMessageBox.NoRole)
msgBox.addButton("Restart now", QMessageBox.YesRole)
if msgBox.exec_() == 1:
self.versionManager.actions.restart()
\ No newline at end of file
from PyQt5.QtGui import QIcon, QFont, QPalette
from PyQt5 import QtCore
import backend.barista.caffe_versions as caffeVersions
from PyQt5.QtWidgets import (
QWidget,
QGridLayout,
QLabel,
QPushButton,
qApp,
QStyle,
QListWidget,
QListWidgetItem,
QAbstractItemView
)
class CaffeVersionWidget(QWidget):
"""This widget class represents a remote caffe version"""
def __init__(self, version, parent, hostdialog, isSelected):
super(CaffeVersionWidget, self).__init__(parent)
self.caffe_version = version
self.hostdialog = hostdialog
self.boldFont = QFont()
self.boldFont.setBold(True)
"""Icons"""
self.ico_delete = QIcon('resources/trash.png')
self.ico_select = QIcon('resources/select.png')
"""Layout"""
self.layout = QGridLayout(self)
"""General buttons"""
self.btnDelete = QPushButton(self.ico_delete, "")
self.btnDelete.setToolTip("Delete Version")
self.btnDelete.setFixedSize(30,30)
self.layout.addWidget(self.btnDelete, 0, 0, 1, 1)
self.btnSelect = QPushButton(self.ico_select, "")
self.btnSelect.setToolTip("Select Version")
self.btnSelect.setFixedSize(30, 30)
self.layout.addWidget(self.btnSelect, 0, 1, 1, 1)
"""Name label"""
self.lblName = QLabel(self.caffe_version.getName())
self.lblName.setFont(self.boldFont)
self.lblName.setAlignment(QtCore.Qt.AlignCenter)
self.layout.addWidget(self.lblName, 0, 2, 1, -1)
if isSelected:
"""show details of the current version if it is selected"""
self.lblDetailsTitle = QLabel("Details:")
self.lblDetailsTitle.setFont(self.boldFont)
self.layout.addWidget(self.lblDetailsTitle, 1, 0, 1, -1)
self.layout2 = QGridLayout()
#btnEdit = QPushButton(self.ico_edit, "")
#btnEdit.setFixedSize(30,30)
#self.layout2.addWidget(btnEdit, 0, 0, 1, 1)
self.layout2.addWidget(QLabel("Root:\t"+self.caffe_version.getRootpath()), 0, 1, 1, -1)
#btnEdit = QPushButton(self.ico_edit, "")
#btnEdit.setFixedSize(30,30)
#self.layout2.addWidget(btnEdit, 1, 0, 1, 1)
self.layout2.addWidget(QLabel("Binary:\t"+self.caffe_version.getBinarypath()), 1, 1, 1, -1)
#btnEdit = QPushButton(self.ico_edit, "")
#btnEdit.setFixedSize(30,30)
#self.layout2.addWidget(btnEdit, 2, 0, 1, 1)
self.layout2.addWidget(QLabel("Python:\t"+self.caffe_version.getPythonpath()), 2, 1, 1, -1)
#btnEdit = QPushButton(self.ico_edit, "")
#btnEdit.setFixedSize(30,30)
#self.layout2.addWidget(btnEdit, 3, 0, 1, 1)
self.layout2.addWidget(QLabel("Proto:\t"+self.caffe_version.getProtopath()), 3, 1, 1, -1)
self.layout.addLayout(self.layout2, 2, 0, 1, -1)
self.btnDelete.clicked.connect(self._onRemoveVersion)
self.btnSelect.clicked.connect(self._onSelectVersion)
def _onRemoveVersion(self):
"""Remove this version from the remote host"""
self.hostdialog._onRemoveVersion(self.caffe_version)
def _onSelectVersion(self):
"""Link this version to local caffe version"""
self.hostdialog._onLinkVersions(self.caffe_version)
def getVersionName(self):
"""Return the name of the caffe version represented in this widget"""
return self.caffe_version.getName()
from PyQt5.QtWidgets import (
QDialog,
QListWidget,
QListWidgetItem,
QAbstractItemView,
QGridLayout,
QLabel,
QPushButton,
QMessageBox,
QWidget
)
from PyQt5.QtGui import QFont
from gui.host_manager.host_manager import Host
from gui.caffepath_dialog import CaffepathDialog
import backend.barista.caffe_versions as caffeVersions
from backend.networking.net_util import sendMsgToHost
from gui.caffe_version_manager.caffe_version_widget_remote import CaffeVersionWidget
from backend.barista.session.session import State
from backend.networking.protocol import Protocol
import backend.barista.hash as Hash
import os
class LinkToHostDialog(QDialog):
"""This Dialog gives the user the opportunity to link a local caffe version to caffe version on a remote host.
Also new caffe versions can be added to hosts"""
def __init__(self, hostManager, version, caffeVersionManager, parent=None, remote=False):
super(LinkToHostDialog, self).__init__(parent)
self.resize(800, 300)
self.hostManager = hostManager
self.caffeVersionManager = caffeVersionManager
self.hosts = {}
self.selectedHost = None
self.version = version
self.selectedVersion = None
self.setWindowTitle("Link " + self.version.getName() + " to a remote host")
self.layout = QGridLayout(self)
self.lblHosts = QLabel("Available hosts")
self.layout.addWidget(self.lblHosts, 0, 0, 1, 2)
self.lstHosts = QListWidget(self)
self.lstHosts.setSelectionMode(QAbstractItemView.SingleSelection)
self.layout.addWidget(self.lstHosts, 1, 0, -1, 2)
self.lblVersions = QLabel("Caffe versions on host")
self.layout.addWidget(self.lblVersions, 0, 2, 1, 3)
self.lstVersions = QListWidget(self)
self.lstVersions.setSelectionMode(QAbstractItemView.SingleSelection)
self.lstVersions.setStyleSheet(self.lstVersions.styleSheet() + "QListWidget::item { border-bottom: 1px solid lightgray; }" )
self.layout.addWidget(self.lstVersions, 1, 2, 1, 3)
self.btnAddVersion = QPushButton("Add new version")
self.btnAddVersion.setEnabled(False)
self.layout.addWidget(self.btnAddVersion, 2, 2, 1, 3)
self.lstHosts.itemSelectionChanged.connect(self._onHostSelectionChanged)
self.lstVersions.itemSelectionChanged.connect(self._onVersionSelectionChanged)
self.btnAddVersion.clicked.connect(self._onAddVersion)
self.updateHosts()
self.displayHosts()
def updateHosts(self):
"""Update the list of available hosts"""
for host in self.hostManager.getActiveHostList():
self.hosts[host[0]] = self.hostManager.getHostById(host[0])
def displayHosts(self):
"""Displays the list of available hosts"""
self.lstHosts.clear()
selectedItem = None
for host in self.hosts:
widget = HostWidget(self.hosts[host])
item = QListWidgetItem()
item.setSizeHint(widget.sizeHint())
self.lstHosts.addItem(item)
self.lstHosts.setItemWidget(item, widget)
if self.selectedHost is not None and host == self.selectedHost.id:
selectedItem = item
if self.selectedHost is not None:
self.lstHosts.setCurrentItem(selectedItem)
def _onHostSelectionChanged(self):
"""Show the caffe versions of the selected host in seperate list"""
self.lstVersions.clear()
if len(self.lstHosts.selectedItems()) > 0:
self.selectedHost = self.lstHosts.itemWidget(self.lstHosts.currentItem()).host
self.btnAddVersion.setEnabled(True)
self.updateVersionList()
else:
self.btnAddVersion.setEnabled(False)
def updateVersionList(self):
self.lstVersions.clear()
host = self.lstHosts.itemWidget(self.lstHosts.currentItem()).host
msg = {"key": Protocol.GETCAFFEVERSIONS}
reply = sendMsgToHost(host.host, host.port, msg)
if reply:
versions = reply["versions"]
for version in versions:
tempVer = caffeVersions.caffeVersion(version, versions[version]["root"], versions[version]["binary"], versions[version]["python"], versions[version]["proto"])
widget = CaffeVersionWidget(tempVer, self.lstVersions, self, (tempVer.getName() == self.selectedVersion))
item = QListWidgetItem()
item.setSizeHint(widget.sizeHint())
self.lstVersions.addItem(item)
self.lstVersions.setItemWidget(item, widget)
def _onVersionSelectionChanged(self):
"""Enable/disable button to link the caffe versions according to the number of selected entries"""
if len(self.lstVersions.selectedItems()) > 0:
currentName = self.lstVersions.itemWidget(self.lstVersions.currentItem()).getVersionName()
if currentName == self.selectedVersion:
self.selectedVersion = None
else:
self.selectedVersion = currentName
self.updateVersionList()
def _onAddVersion(self):
"""Opens the dialog to add a caffe version for the host and sends the version to the host"""
host = self.lstHosts.itemWidget(self.lstHosts.currentItem()).host
caffedlg = CaffepathDialog("Add a new caffe version to the Host", "Add version", remote=True, host=host)
caffedlg.exec_()
self.updateHosts()
self.displayHosts()
def _onRemoveVersion(self, version):
"""Delete the selected caffe-version from the host"""
host = self.lstHosts.itemWidget(self.lstHosts.currentItem()).host
versionName = version.getName()
msg = {"key": Protocol.REMOVECAFFEVERSION, "versionname": versionName}
sendMsgToHost(host.host, host.port, msg)
self.version.removeHostVersion({"host": host.host, "port": host.port, "versionname": versionName})
self.updateHosts()
self.displayHosts()
def _onLinkVersions(self, version):
"""Check if the two versions are the same, if so link them by saving the host and the selected version in the current lokal version"""
host = self.lstHosts.itemWidget(self.lstHosts.currentItem()).host
identical, differences = self.compareVersions(self.version, version, host)
ret = 0
if not identical:
text = "Could not link remote to local version. They differ in the following files/folders:\n\n"
for dif in differences:
text += dif["local"] + " - " + dif["remote"] + "\n"
msgBox = QMessageBox(QMessageBox.Warning, "Warning", text)
msgBox.addButton("Ok", QMessageBox.NoRole)
msgBox.addButton("Add anyway", QMessageBox.YesRole)
ret = msgBox.exec_()
if identical or ret == 1:
self.version.addHostVersion({"host": host.host, "port": host.port, "versionname": version.getName()})
sessions = self.caffeVersionManager.project.getSessions()
for sessionID in sessions:
session = sessions[sessionID]
try:
sessionHost = session.remote[0]
sessionPort = session.remote[1]
if sessionHost == host.host and sessionPort == host.port and session.getState() == State.NOTLINKED:
session.setState(State.PAUSED)
except:
pass
caffeVersions.saveVersions()
self.close()
def compareVersions(self, localVersion, remoteVersion, host):
"""Hashes the given versions and returns true if they are identical, false otherwise.
The differences will also be returned"""
identical = True
differences = []
"""Hash binaries"""
localBinary = localVersion.getBinarypath()
remoteBinary = remoteVersion.getBinarypath()
localHash = Hash.hashFile(localBinary)
remoteHash = sendMsgToHost(host.host, host.port, {"key": Protocol.GETFILEHASH, "path": remoteBinary})["hash"]
if localHash != remoteHash:
identical = False
differences.append({"local": localBinary, "remote": remoteBinary})
"""Hash protos"""
localProto = localVersion.getProtopath()
remoteProto = remoteVersion.getProtopath()
localHash = Hash.hashFile(localProto)
remoteHash = sendMsgToHost(host.host, host.port, {"key": Protocol.GETFILEHASH, "path": remoteProto})["hash"]
if localHash != remoteHash:
identical = False
differences.append({"local": localProto, "remote": localProto})
"""Hash include paths"""
localIncludePath = os.path.join(localVersion.getRootpath(), "include")
remoteIncludePath = os.path.join(remoteVersion.getRootpath(), "include")
localHash = Hash.hashDir(localIncludePath)
remoteHash = sendMsgToHost(host.host, host.port, {"key": Protocol.GETDIRHASH, "path": remoteIncludePath})["hash"]
if localHash != remoteHash:
identical = False
differences.append({"local": localIncludePath, "remote": remoteIncludePath})
"""Return if versions are identical and the differences"""
return identical, differences
class HostWidget(QWidget):
"""This widget represents a host displayed in the link_to_host_dialog"""
def __init__(self, host, parent=None):
super(HostWidget, self).__init__(parent)
self.host = host
self.layout = QGridLayout(self)
self.lblTitle = QLabel("'" + host.name + "' on port '" + str(host.port) + "'")
font = QFont()
font.setBold(True)
self.lblTitle.setFont(font)
self.layout.addWidget(self.lblTitle, 0, 0, 1, 1)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment