changed pathloader from being a module to a singleton class to prevent wrong...

changed pathloader from being a module to a singleton class to prevent wrong caffe-paths in global variable
parent 9b89dc77
......@@ -99,8 +99,8 @@ def _areParameterValid(allparams, paramdict, prefix=""):
def bareNet(name):
""" Creates a dictionary of a networks with default values where required. """
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
net = proto.NetParameter()
descr = info.ParameterGroupDescriptor(net)
params = descr.parameter().copy()
......@@ -116,8 +116,8 @@ def _bareLayer(layertype, name):
""" Creates a dictionary of the given layertype with the given name
initialized with default values if required.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
res = {"type": layertype}
layerParamInst = proto.LayerParameter()
res["parameters"] = _extract_param(layerParamInst, info.CaffeMetaInformation().commonParameters())
......
......@@ -17,8 +17,8 @@ class ParseException(Exception):
def loadSolver(solverstring):
""" Return a dictionary which represent the caffe-solver-prototxt solverstring """
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
solver = proto.SolverParameter()
# Get DESCRIPTION for meta infos
......@@ -66,8 +66,8 @@ def loadNet(netstring):
}
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
# Load Protoclass for parsing
net = proto.NetParameter()
......@@ -175,8 +175,8 @@ def extractNetFromSolver(solverstring):
This works only, if the solver specifies a network using the "net_param" parameter. A reference to a file using the
"net" parameter can not be handled by this method.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
# create empty solver message
solver = proto.SolverParameter()
......@@ -195,8 +195,8 @@ def extractNetFromSolver(solverstring):
def getCaffemodelFromSolverstate(solverstate):
""" Parse the filename of the caffemodel file from the solverstate file.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
try:
state = proto.SolverState()
......@@ -220,8 +220,8 @@ def getCaffemodelFromSolverstateHdf5(filename):
def getIterFromSolverstate(solverstate):
""" Parse the iterations from the solverstate file.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
try:
state = proto.SolverState()
......
import imp
from backend.barista.utils.settings import applicationQSetting
import backend.barista.caffe_versions as caffeVersions
import os
#todo: move this function to an new class
path = caffeVersions.getDefaultVersion().getPythonpath()
def importCaffe():
import importlib
import sys
global path
sys.path.insert(0, path) # stellt sicher, dass der angegebene Pfad als erstes durchsucht wird
try:
caffe = importlib.import_module("caffe")
sys.path.pop(0)
return caffe
except ImportError as e:
print(e)
exit(1)
def importProto():
import importlib
import sys
global path
sys.path.insert(0, path) # stellt sicher, dass der angegebene Pfad als erstes durchsucht wird
try:
proto = importlib.import_module("caffe.proto.caffe_pb2")
sys.path.pop(0)
return proto
except ImportError as e:
print(e)
exit(1)
def changeCaffeVersion(name):
global path
path = caffeVersions.getVersionByName(name).getPythonpath()
class Singleton(type):
"""This metaclass is used to provide the singleton pattern in a generic way.
See http://stackoverflow.com/a/6798042 for source and further explanation.
TODO outsource this class to use the same pattern in the complete project?
"""
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class PathLoader:
__metaclass__ = Singleton
def __init__(self):
self.path = caffeVersions.getDefaultVersion().getPythonpath()
def importCaffe(self):
import importlib
import sys
sys.path.insert(0, self.path) # stellt sicher, dass der angegebene Pfad als erstes durchsucht wird
try:
caffe = importlib.import_module("caffe")
sys.path.pop(0)
return caffe
except ImportError as e:
print(e)
exit(1)
def importProto(self):
import importlib
import sys
sys.path.insert(0, self.path) # stellt sicher, dass der angegebene Pfad als erstes durchsucht wird
try:
proto = importlib.import_module("caffe.proto.caffe_pb2")
sys.path.pop(0)
return proto
except ImportError as e:
print(e)
exit(1)
def changeCaffeVersion(self, name):
self.path = caffeVersions.getVersionByName(name).getPythonpath()
......@@ -49,8 +49,8 @@ class CaffeMetaInformation:
See description of self.availableParameterGroupDescriptors().
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
current_module = sys.modules[proto.__name__]
res = {}
for (el,val) in inspect.getmembers(current_module, inspect.isclass):
......@@ -61,8 +61,8 @@ class CaffeMetaInformation:
"""Generate information about available layer types only once.
self.__initAvailableParameterGroupDescriptors() needs to be called before this method."""
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
from backend.caffe.path_loader import PathLoader
caffe = PathLoader().importCaffe()
layerNameMainParts = list(caffe.layer_type_list())
res = {}
......@@ -111,8 +111,8 @@ class CaffeMetaInformation:
"""Generate information about available solver types only once.
self.__initAvailableParameterGroupDescriptors() needs to be called before this method."""
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
# DO NOT REMOVE the following import statement, although your IDE might say it's unused. It's not!
from caffe._caffe import Solver as SolverBaseClassInfo
......@@ -486,8 +486,8 @@ def resetCaffeProtoModulesvar():
def _caffeProtobufModules():
""" Returns all available Classes of caffe_pb2 in a dictionary """
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
global _caffeprotomodulesvar
if _caffeprotomodulesvar is None:
current_module = sys.modules[proto.__name__]
......
......@@ -9,8 +9,8 @@ def saveSolver(solverdict):
def _import_solver(solverdict):
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
solver = proto.SolverParameter()
for entry in solverdict:
......@@ -67,8 +67,8 @@ def saveNet(netdict):
def _import_dictionary(netdict):
"""fill the ProtoTxt-Net with data from the dictionary"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
net = proto.NetParameter()
for entry in netdict:
......
......@@ -51,8 +51,8 @@ class LeveldbInput:
def verifyConsistency(self):
'''check if all entries have the same channel number and size.
This may take a lot of time since every entry has to be checked.'''
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
from backend.caffe.path_loader import PathLoader
caffe = PathLoader().importCaffe()
if self.getDataCount() < 2:
return True
first = self._getFirstDatum()
......@@ -86,8 +86,8 @@ class LeveldbInput:
return self._db.RangeIter()
def _getFirstDatum(self):
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
from backend.caffe.path_loader import PathLoader
caffe = PathLoader().importCaffe()
iter = self._getIter()
if iter:
for key, value in iter:
......
......@@ -96,8 +96,8 @@ class LmdbInput:
return self._getCurrentDatum()
def _getCurrentDatum(self):
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
from backend.caffe.path_loader import PathLoader
caffe = PathLoader().importCaffe()
if self._cursor:
raw_datum = self._cursor.value()
datum = caffe.proto.caffe_pb2.Datum()
......
......@@ -6,7 +6,6 @@ import logging
from PyQt5.QtNetwork import QTcpServer, QHostAddress
#from backend.caffe import path_loader
from backend.caffe.check_hardware import checkHardware
from backend.barista.session.session_utils import State
from backend.networking.protocol import Protocol
......
......@@ -12,8 +12,8 @@ class SolverPropertyInfoBuilder(data.PropertyInfoBuilder):
return PropertyInfo(name, protoparameter, uri)
def buildRootInfo(self, name=""):
""" Builds "solver" info object of the state dictionary """
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
protosolver = proto.SolverParameter()
descr = info.ParameterGroupDescriptor(protosolver)
......
# from https://www.eriksmistad.no/visualizing-learned-features-of-a-caffe-neural-network/
# test code, will be changed
import backend.caffe.path_loader as caffeLoader
caffe = caffeLoader.importCaffe()
from backend.caffe.path_loader import PathLoader
caffe = PathLoader().importCaffe()
from gui.main_window.docks.weight_visualization import visualize_weights
......
......@@ -7,8 +7,8 @@ import numpy as np
def loadNetParameter(caffemodel):
""" Return a NetParameter protocol buffer loaded from the caffemodel.
"""
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
net = proto.NetParameter()
try:
......@@ -19,8 +19,8 @@ def loadNetParameter(caffemodel):
pass
def loadNetParamFromString(paramstring):
import backend.caffe.path_loader as caffeLoader
proto = caffeLoader.importProto()
from backend.caffe.path_loader import PathLoader
proto = PathLoader().importProto()
net = proto.NetParameter()
try:
net.ParseFromString(paramstring)
......
......@@ -161,7 +161,7 @@ class UiMainWindow(QtWidgets.QMainWindow):
event.ignore()
def _onProjectChanged(self, project):
from backend.caffe import path_loader
from backend.caffe.path_loader import PathLoader
self.setWindowTitle("Barista - {}".format(project.getProjectName()))
#reload the caffe information
......
......@@ -84,10 +84,10 @@ class StartDialog(QtWidgets.QDialog):
self.loadProject(proj)
def loadProject(self, dir):
import backend.caffe.path_loader as pathLoader
from backend.caffe.path_loader import PathLoader
resetCaffeProtoModulesvar()
proj = Project(dir)
pathLoader.changeCaffeVersion(proj.getCaffeVersion())
PathLoader().changeCaffeVersion(proj.getCaffeVersion())
win = main_window.UiMainWindow(self._defAct)
self._defAct.setProject(proj)
self.__switchMainWindow(win)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment