import copy
import os
import sys
import datetime
import inspect
import json
import functools
import platform
import time
from packaging import version as version_mod
from pathlib import Path
import traceback
from typing import Any, cast, List, Optional, Tuple
from typing import Iterable as IterableType
from collections.abc import Iterable
import numpy as np
from pymodaq_utils import logger as logger_module
from pymodaq_utils.config import Config
from pymodaq_utils.warnings import deprecation_msg
from pymodaq_utils.serialize.factory import SerializableFactory, SerializableBase
from importlib import metadata
# Re-exported from importlib.metadata for use elsewhere
PackageNotFoundError = metadata.PackageNotFoundError

# functools.cache exists from Python 3.9 on; on older interpreters lru_cache
# (unbounded when used as a bare decorator) provides the same behaviour.
if sys.version_info >= (3, 9):
    from functools import cache
else:
    from functools import lru_cache as cache

# module-level logger named after this module
logger = logger_module.set_logger(logger_module.get_module_name(__file__))
# configuration object queried below (plot colors, default saving path, ...)
config = Config()
[docs]
class PlotColors:
    """Cyclic palette of RGB colors used for plotting.

    Each color is stored as a 3-tuple of 8-bit integers (0-255). Integer indexing
    wraps around the palette, so ``colors[len(colors)]`` is ``colors[0]``.

    Parameters
    ----------
    colors: iterable of 3-element iterables of int, optional
        The palette. When omitted (None), it is read from the
        ``('plotting', 'plot_colors')`` configuration entry at instantiation time.
    """

    def __init__(self, colors=None):
        if colors is None:
            # resolved here rather than as a default argument so the config is not
            # queried once at class-definition time
            colors = config('plotting', 'plot_colors')[:]
        if isinstance(colors, Iterable):
            # materialize once: a generator would otherwise be exhausted by the
            # validation pass before the palette is built
            colors = list(colors)
        self._internal_counter = -1
        self.check_colors(colors)
        self._plot_colors = [tuple(color) for color in colors]

    def copy(self):
        """Return an independent copy: removing colors from it leaves self untouched."""
        new = copy.copy(self)
        # copy.copy is shallow; give the copy its own list so remove() does not leak
        new._plot_colors = list(self._plot_colors)
        return new

    def remove(self, item):
        """Remove the color *item* (a 3-tuple) from the palette."""
        self._plot_colors.remove(item)

    def __getitem__(self, item: int):
        if not isinstance(item, int):
            raise TypeError('getter should be an integer')
        # wrap around so any integer maps to a palette entry
        return tuple(self._plot_colors[item % len(self._plot_colors)])

    def __len__(self):
        return len(self._plot_colors)

    def __iter__(self):
        # Reset the legacy counter used by __next__, but hand out an independent
        # iterator over a snapshot so nested loops over the same palette work.
        self._internal_counter = -1
        return iter(list(self._plot_colors))

    def __next__(self):
        # kept for backward compatibility with code calling next() on the palette
        if self._internal_counter >= len(self) - 1:
            raise StopIteration
        self._internal_counter += 1
        return self[self._internal_counter]

    def check_colors(self, colors: IterableType):
        """Raise TypeError unless *colors* is an iterable of valid colors."""
        if not isinstance(colors, Iterable):
            raise TypeError('Colors should be a list of 3-tuple 8 bits integer (0-255)')
        for color in colors:
            self.check_color(color)

    @staticmethod
    def check_color(color: IterableType):
        """Raise TypeError unless *color* holds exactly three ints in [0, 255]."""
        # `or` (not `and`): reject non-iterables AND iterables of the wrong length
        if not isinstance(color, Iterable) or len(color) != 3:
            raise TypeError('Colors should be a list of 3-tuple 8 bits integer (0-255)')
        for col_val in color:
            if not (isinstance(col_val, int) and 0 <= col_val <= 255):
                raise TypeError('Colors should be a list of 3-tuple 8 bits integer (0-255)')
# Module-level default palette, built from the ('plotting', 'plot_colors') config entry
plot_colors = PlotColors()
[docs]
def is_64bits():
    """Return True when running on a 64-bit Python build.

    ``sys.maxsize`` is usually 2**31 - 1 on 32-bit builds and 2**63 - 1 on 64-bit ones.
    """
    return 2 ** 32 < sys.maxsize
[docs]
def timer(func):
    """Decorator printing how long each call to *func* takes.

    The wrapped function's return value is passed through unchanged; the elapsed
    time is measured with ``time.perf_counter`` and printed with 4 decimals.
    """
    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        tic = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - tic
        print(f"Finished {func.__name__!r} in {elapsed:.4f} secs")
        return result

    return wrapper_timer
[docs]
def get_version(package_name='pymodaq'):
    """Return the installed version string of the distribution *package_name*.

    Relies on ``importlib.metadata.version``, which raises
    ``PackageNotFoundError`` when the distribution is not installed.
    """
    version_string = metadata.version(package_name)
    return version_string
[docs]
class JsonConverter:
    """Round-trip simple Python objects through a small JSON envelope.

    ``object2json`` stores the object's module, type name and ``repr``;
    ``json2object`` rebuilds the object from that ``repr`` when the declared
    type name is listed in ``trusted_types``.
    """

    def __init__(self):
        super().__init__()

    @classmethod
    def trusted_types(cls):
        """Names of the types whose ``repr`` json2object evaluates back into an object."""
        return ['float', 'int', 'str', 'datetime', 'date', 'time', 'tuple', 'list', 'bool', 'bytes',
                'float64']

    @classmethod
    def istrusted(cls, type_name):
        """Return True if *type_name* is one of ``trusted_types()``."""
        return type_name in cls.trusted_types()

    @classmethod
    def object2json(cls, obj):
        """Serialize *obj* as a JSON object with ``module``, ``type`` and ``data`` (its repr) keys."""
        obj_type = type(obj)
        envelope = {'module': obj_type.__module__, 'type': obj_type.__name__, 'data': repr(obj)}
        return json.dumps(envelope)

    @classmethod
    def json2object(cls, jsonstring):
        """Decode a string produced by ``object2json``.

        Returns
        -------
        - the rebuilt object when the envelope's ``type`` is trusted;
        - the decoded JSON value itself when it is not a dict or its type is not trusted;
        - *jsonstring* unchanged when anything fails (invalid JSON, missing ``type``
          key, ``data`` that cannot be evaluated...).
        """
        try:
            dic = json.loads(jsonstring)
            if not isinstance(dic, dict) or dic['type'] not in cls.trusted_types():
                return dic
            # SECURITY NOTE(review): eval executes arbitrary code. The trust check only
            # inspects the *declared* type name, not ``data`` itself, so a crafted payload
            # from an untrusted source (e.g. a network peer) can run code here. Consider
            # ast.literal_eval plus explicit parsing of datetime/date/time values.
            return eval(dic['data'])
        except Exception:
            return jsonstring
[docs]
def capitalize(string, Nfirst=1):
    """
    Return *string* with its first *Nfirst* characters upper-cased; the rest is unchanged.

    Parameters
    ----------
    string: (str)
    Nfirst: (int) number of leading characters to upper-case

    Returns
    -------
    str
    """
    head, tail = string[:Nfirst], string[Nfirst:]
    return head.upper() + tail
[docs]
def uncapitalize(string, Nfirst=1):
    """
    Return *string* with its first *Nfirst* characters lower-cased (counterpart of capitalize).

    Parameters
    ----------
    string: (str)
    Nfirst: (int) number of leading characters to lower-case

    Returns
    -------
    str
    """
    prefix = string[:Nfirst].lower()
    return prefix + string[Nfirst:]
[docs]
def getLineInfo():
    """Return the formatted traceback of the exception currently being handled.

    When no exception is being handled, ``sys.exc_info()`` yields no traceback
    and an empty string is returned.
    """
    _, _, exc_tb = sys.exc_info()
    return ''.join(traceback.format_tb(exc_tb))
[docs]
@SerializableFactory.register_decorator()
class ThreadCommand(SerializableBase):
    """Generic object to pass info (command) and data (attribute) between thread or objects using signals

    Parameters
    ----------
    command: str
        The command to be analysed for further action
    attribute: any type
        the attribute related to the command. The actual type and value depend on the command and the situation
    attributes: deprecated, attribute should be used instead
        Only used when ``attribute`` is None; its value then becomes ``attribute``
        (and is also kept as ``self.attributes``).
    args: tuple
        positional values sent along with the command
    kwargs: dict, optional
        keyword values sent along with the command (an empty dict when None)

    Attributes
    ----------
    command : str
        The command to be analysed for further action
    attribute : any type
        the attribute related to the command. The actual type and value depend on the command and the situation
    args: some variables in a list
    kwargs: some variables in a dict
    """
    command: str
    attribute: Any
    args: list
    kwargs: dict

    def __init__(self, command: str, attribute=None, attributes=None, args=(),
                 kwargs: Optional[dict] = None):
        if not isinstance(command, str):
            raise TypeError(f'The command in a Threadcommand object should be a string, not a {type(command)}')
        self.command = command
        if attribute is None and attributes is not None:
            deprecation_msg('ThreadCommand signature changed, use attribute in place of attributes')
            # Use the deprecated value as the attribute. Previously it was assigned and
            # then immediately overwritten by `attribute` (None), losing the data.
            attribute = attributes
            self.attributes = attributes
        self.attribute = attribute
        self.args = args
        self.kwargs = {} if kwargs is None else kwargs

    def __eq__(self, other: Any) -> bool:
        """Two commands are equal when command, attribute, args and kwargs all compare equal."""
        if not isinstance(other, ThreadCommand):
            return NotImplemented
        return (
            self.command == other.command
            and self.attribute == other.attribute
            and self.args == other.args
            and self.kwargs == other.kwargs
        )

    @staticmethod
    def serialize(obj: "ThreadCommand") -> bytes:  # type: ignore[override]
        """Serialize *obj* as the concatenation of command, attribute, args and kwargs."""
        serialize_factory = SerializableFactory()
        byte_string = b""
        # the order here must match the order used in `deserialize`
        byte_string += serialize_factory.get_apply_serializer(obj.command)
        byte_string += serialize_factory.get_apply_serializer(obj.attribute)
        byte_string += serialize_factory.get_apply_serializer(obj.args)
        byte_string += serialize_factory.get_apply_serializer(obj.kwargs)
        return byte_string

    @staticmethod
    def deserialize(bytes_str: bytes) -> Tuple["ThreadCommand", bytes]:
        """Rebuild a ThreadCommand from *bytes_str*.

        Returns the command and the bytes left over after its four serialized fields.
        """
        serialize_factory = SerializableFactory()
        command, remaining = cast(
            Tuple[str, bytes],
            serialize_factory.get_apply_deserializer(bytes_str=bytes_str, only_object=False),
        )
        attribute, remaining = cast(
            Tuple[Any, bytes], serialize_factory.get_apply_deserializer(remaining, False)
        )
        args, remaining = cast(
            Tuple[list, bytes],
            serialize_factory.get_apply_deserializer(remaining, False)
        )
        kwargs, remaining = cast(
            Tuple[dict, bytes],
            serialize_factory.get_apply_deserializer(remaining, False)
        )
        return ThreadCommand(command, attribute, args=tuple(args), kwargs=kwargs), remaining

    def __repr__(self):
        return f'Threadcommand: {self.command} with attribute {self.attribute}'
[docs]
def ensure_ndarray(data):
    """
    Make sure data is returned as a numpy array

    ndarrays are returned as-is (same object), lists are converted with
    ``np.array`` and any other value is wrapped in a one-element list first.

    Parameters
    ----------
    data

    Returns
    -------
    ndarray
    """
    if isinstance(data, np.ndarray):
        return data
    return np.array(data if isinstance(data, list) else [data])
[docs]
def recursive_find_files_extension(ini_path, ext, paths=None):
    """Recursively collect the paths of files whose extension is *ext*.

    Parameters
    ----------
    ini_path: str or path-like
        folder where the search starts
    ext: str
        extension to match, without the leading dot (e.g. 'py')
    paths: list, optional
        list the found paths (str) are appended to; a new list is used when None

    Returns
    -------
    list of str: the *paths* list, extended with the matching file paths
    """
    # A `[]` default would be shared between calls and accumulate results.
    if paths is None:
        paths = []
    with os.scandir(ini_path) as it:
        for entry in it:
            if entry.is_file() and os.path.splitext(entry.name)[1][1:] == ext:
                paths.append(entry.path)
            elif entry.is_dir():
                recursive_find_files_extension(entry.path, ext, paths)
    return paths
[docs]
def recursive_find_files(ini_path, exp='make_enum', paths=None,
                         filters=['build']):
    """Recursively collect files whose stem contains *exp*.

    Parameters
    ----------
    ini_path: str or Path
        folder where the search starts
    exp: str
        substring looked for in each file stem
    paths: list, optional
        list the found Path objects are appended to; a new list is used when None
    filters: list of str
        a file is skipped if any of these substrings appears in its full path

    Returns
    -------
    list of Path: the *paths* list, extended with the matching files
    """
    # A `[]` default would be shared between calls and accumulate results.
    if paths is None:
        paths = []
    for child in Path(ini_path).iterdir():
        if child.is_dir():
            recursive_find_files(child, exp, paths, filters)
        elif exp in child.stem and not any(filt in str(child) for filt in filters):
            paths.append(child)
    return paths
[docs]
def recursive_find_expr_in_files(ini_path, exp='make_enum', paths=None,
                                 filters=['.git', '.idea', '__pycache__', 'build', 'egg', 'documentation', '.tox'],
                                 replace=False, replace_str=''):
    """Recursively look for the text *exp* in the lines of every file below *ini_path*.

    Parameters
    ----------
    ini_path: str or Path
        folder where the search starts
    exp: str
        text to look for in each line
    paths: list, optional
        list the hits are appended to as ``[file_path, line_index, line]``;
        a new list is used when None
    filters: list of str
        files/folders whose full path contains any of these substrings are skipped
    replace: bool
        if True, every occurrence of *exp* is replaced by *replace_str* and the
        file is rewritten (only files with at least one hit are rewritten)
    replace_str: str
        replacement text used when *replace* is True

    Returns
    -------
    list: the *paths* list, extended with the hits
    """
    # A `[]` default would be shared between calls and accumulate results.
    if paths is None:
        paths = []
    for child in Path(ini_path).iterdir():
        if any(filt in str(child) for filt in filters):
            continue
        if child.is_dir():
            recursive_find_expr_in_files(child, exp, paths, filters, replace=replace, replace_str=replace_str)
            continue
        try:
            found = False
            replacement = []
            with child.open('r') as f:
                for ind, line in enumerate(f):
                    if exp in line:
                        found = True
                        paths.append([child, ind, line])
                        if replace:
                            replacement.append(line.replace(exp, replace_str))
                    elif replace:
                        replacement.append(line)
            if replace and found:
                with child.open('w') as f:
                    f.write(''.join(replacement))
        except (OSError, UnicodeError):
            # Unreadable or non-text (binary) files are skipped on purpose.
            pass
    return paths
[docs]
def count_lines(ini_path, count=0, filters=['lextab', 'yacctab', 'pycache', 'pyc']):
    """Recursively count the lines of the Python files below *ini_path*.

    Files whose name contains '.py' are counted unless their name contains one of
    *filters*; the stem of every other non-filtered file is printed.

    Parameters
    ----------
    ini_path: str or Path
        folder where the count starts
    count: int
        value the line count is added to
    filters: list of str
        files whose name contains any of these substrings are ignored
        (applied in every sub-folder as well)

    Returns
    -------
    int: *count* plus the number of counted lines
    """
    for child in Path(ini_path).iterdir():
        if child.is_dir():
            # propagate the caller's filters (they were previously dropped here)
            count = count_lines(child, count, filters)
            continue
        if any(filt in child.name for filt in filters):
            continue
        if '.py' in child.name:
            try:
                with child.open('r') as f:
                    count += sum(1 for _ in f)
            except (OSError, UnicodeError):
                # unreadable / undecodable files are skipped on purpose
                pass
        else:
            print(child.stem)
    return count
[docs]
def remove_spaces(string):
    """
    Return *string* with every whitespace character (spaces, tabs, newlines...) removed.

    Parameters
    ----------
    string: (str)

    Returns
    -------
    str
    """
    words = string.split()
    return ''.join(words)
[docs]
def rint(x):
    """
    Round *x* to the nearest integer and return it as a Python ``int``.

    Uses ``numpy.rint``, so exact halves go to the nearest even value.

    Parameters
    ----------
    x: (float or integer)

    Returns
    -------
    int: nearest integer
    """
    nearest = np.rint(x)
    return int(nearest)
[docs]
def elt_as_first_element(elt_list, match_word='Mock'):
    """Return a new list with the first string containing *match_word* moved to the front.

    When no element matches, the order is unchanged. The input is not modified.

    Parameters
    ----------
    elt_list: iterable of str
    match_word: str
        substring identifying the element to move first

    Returns
    -------
    list of str

    Raises
    ------
    TypeError: if elt_list is not iterable, or an element examined before the
        match is not a str
    """
    if not hasattr(elt_list, '__iter__'):
        raise TypeError('elt_list must be an iterable')
    # work on a copy: the previous implementation removed the match from the
    # caller's list as a side effect (and required a list-like input)
    elts = list(elt_list)
    if not elts:
        return []
    ind_elt = 0
    for ind, elt in enumerate(elts):
        if not isinstance(elt, str):
            raise TypeError('elt_list must be a list of str')
        if match_word in elt:
            ind_elt = ind
            break
    return [elts[ind_elt]] + elts[:ind_elt] + elts[ind_elt + 1:]
[docs]
def elt_as_first_element_dicts(elt_list, match_word='Mock', key='name'):
    """Return a new list with the first dict whose ``elt[key]`` contains *match_word* moved first.

    When no element matches, the order is unchanged. The input is not modified.

    Parameters
    ----------
    elt_list: iterable of dict
    match_word: str
        substring looked for in each ``elt[key]``
    key: str
        dictionary key holding the value to test

    Returns
    -------
    list of dict

    Raises
    ------
    TypeError: if elt_list is not iterable, or an element examined before the
        match is not a dict
    """
    if not hasattr(elt_list, '__iter__'):
        raise TypeError('elt_list must be an iterable')
    # work on a copy: the previous implementation removed the match from the
    # caller's list as a side effect (and required a list-like input)
    elts = list(elt_list)
    if not elts:
        return []
    ind_elt = 0
    for ind, elt in enumerate(elts):
        if not isinstance(elt, dict):
            raise TypeError('elt_list must be a list of dicts')
        if match_word in elt[key]:
            ind_elt = ind
            break
    return [elts[ind_elt]] + elts[:ind_elt] + elts[ind_elt + 1:]
[docs]
def find_keys_from_val(dict_tmp: dict, val: object):
    """Return, in iteration order, every key of *dict_tmp* whose value equals *val*."""
    matching_keys = []
    for key, value in dict_tmp.items():
        if value == val:
            matching_keys.append(key)
    return matching_keys
[docs]
def find_object_if_matched_attr_name_val(obj, attr_name, attr_value):
    """check if an attribute name/value pair matches in a given object

    Parameters
    ----------
    obj: object
    attr_name: str
        attribute name to look for in the object
    attr_value: object
        value to match

    Returns
    -------
    bool: True if obj has the attribute attr_name and its value equals attr_value
    """
    if hasattr(obj, attr_name):
        if getattr(obj, attr_name) == attr_value:
            return True
    return False


def find_objects_in_list_from_attr_name_val(objects: List[object], attr_name: str,
                                            attr_value: object, return_first=True):
    """ lookup within a list of objects. Look for the objects within the list which has the correct attribute name,
    value pair

    Parameters
    ----------
    objects: list
        list of objects
    attr_name: str
        attribute name to look for in the object
    attr_value: object
        value to match
    return_first: bool
        if True return the first objects found in the list else all the objects matching

    Returns
    -------
    if return_first: tuple(object, int), the first matching object and its index,
        or (None, -1) when nothing matches
    else: list of tuple(object, int), every matching object with its index
        (empty list when nothing matches)
    """
    # A matching element is reported even if it is itself None (the previous
    # implementation used `obj is None` as its "not found" marker).
    if return_first:
        for ind, obj in enumerate(objects):
            if find_object_if_matched_attr_name_val(obj, attr_name, attr_value):
                return obj, ind
        return None, -1
    return [(obj, ind) for ind, obj in enumerate(objects)
            if find_object_if_matched_attr_name_val(obj, attr_name, attr_value)]
[docs]
def find_dict_if_matched_key_val(dict_tmp, key, value):
    """
    check if a key/value pair match in a given dictionary

    Parameters
    ----------
    dict_tmp: (dict) the dictionary to be tested
    key: (str) a key string to look for in dict_tmp
    value: (object) any python object

    Returns
    -------
    bool: True if the key/value pair has been found in dict_tmp
    """
    if key in dict_tmp:
        if dict_tmp[key] == value:
            return True
    return False


def find_dicts_in_list_from_key_val(dicts, key, value):
    """ lookup within a list of dicts. Look for all the dicts within the list which have the correct key, value pair

    Parameters
    ----------
    dicts: (list) list of dictionaries
    key: (str) specific key to look for in each dict
    value: value to match

    Returns
    -------
    list of dict: every matching dict, in order (empty list when none matches)
    """
    return [dict_tmp for dict_tmp in dicts if find_dict_if_matched_key_val(dict_tmp, key, value)]


def find_dict_in_list_from_key_val(dicts, key, value, return_index=False):
    """ lookup within a list of dicts. Look for the dict within the list which has the correct key, value pair

    Parameters
    ----------
    dicts: (list) list of dictionaries
    key: (str) specific key to look for in each dict
    value: value to match
    return_index: (bool) if True, also return the index of the found dict

    Returns
    -------
    dict: the first matching dict, or None if not found
    tuple(dict, int): (dict, index), or (None, -1) if not found, when return_index is True
    """
    for ind, dict_tmp in enumerate(dicts):
        if find_dict_if_matched_key_val(dict_tmp, key, value):
            return (dict_tmp, ind) if return_index else dict_tmp
    return (None, -1) if return_index else None
[docs]
def get_entrypoints(group='pymodaq.plugins') -> List[metadata.EntryPoint]:
    """ Get the list of entry points registered under a group

    The importlib.metadata API changed over Python versions, so the newest form
    is tried first, falling back to older ones.

    Parameters
    ----------
    group: str
        the name of the group

    Returns
    -------
    list of EntryPoint (empty if the group has no entry point)
    """
    try:
        # newest API: selection by keyword argument
        found = metadata.entry_points(group=group)
    except TypeError:
        all_entrypoints = metadata.entry_points()
        try:
            found = all_entrypoints.select(group=group)
        except AttributeError:
            # older API: a mapping from group name to entry points
            found = all_entrypoints.get(group, [])
    # whatever container the API returned, hand back a plain list
    return list(found)
[docs]
def check_vals_in_iterable(iterable1, iterable2):
    """Check that two iterables hold equal values, in the same order.

    Any kind of iterable is accepted (list, tuple, ndarray, generator...).

    Raises
    ------
    AssertionError: if the lengths differ or any pair of values differs. Raised
        explicitly so the check still runs under ``python -O``.
    """
    # convert first so len() also works for generators and other lazy iterables
    values1 = list(iterable1)
    values2 = list(iterable2)
    if len(values1) != len(values2):
        raise AssertionError(f'Length mismatch: {len(values1)} != {len(values2)}')
    for ind, (val1, val2) in enumerate(zip(values1, values2)):
        if not val1 == val2:
            raise AssertionError(f'Values differ at index {ind}: {val1!r} != {val2!r}')
[docs]
def caller_name(skip=2):
    """Get a name of a caller in the format module.class.method

    `skip` specifies how many levels of stack to skip while getting caller
    name. skip=1 means "who calls me", skip=2 "who calls my caller" etc.

    An empty string is returned if skipped levels exceed stack height.
    The module part is omitted when the module cannot be determined (e.g. code
    executed directly in a console), the class part when the frame has no
    `self` local, and the function part for module-level code.
    """
    # context=0: only the frame objects are used, so don't load source lines for
    # every frame (inspect.stack() reads one line of context by default)
    stack = inspect.stack(0)
    try:
        if len(stack) < skip + 1:
            return ''
        parentframe = stack[skip][0]
    finally:
        # FrameInfo objects keep frames (including this one) alive: drop them
        # explicitly to avoid a reference cycle
        del stack
    try:
        name = []
        module = inspect.getmodule(parentframe)
        if module:
            name.append(module.__name__)
        # detect classname: a `self` local is taken as a method call
        # (static method calls cannot be detected and look like plain functions)
        if 'self' in parentframe.f_locals:
            name.append(parentframe.f_locals['self'].__class__.__name__)
        codename = parentframe.f_code.co_name
        if codename != '<module>':  # top level usually
            name.append(codename)  # function or a method
    finally:
        del parentframe
    return ".".join(name)
[docs]
def zeros_aligned(n, align, dtype=np.uint32):
    """
    Get a zero-filled array whose data buffer starts on an `align`-byte boundary.

    Parameters
    ----------
    n: (int) number of elements of type dtype in the returned array
    align: (int) memory alignment in bytes
    dtype: (numpy.dtype) type of the stored memory elements

    Returns
    -------
    numpy.ndarray: 1D array of n zeros, a view into a larger uint8 buffer
    """
    item_type = np.dtype(dtype)
    size_in_bytes = n * item_type.itemsize
    # over-allocate by `align` bytes so an aligned window of size_in_bytes always fits
    raw = np.zeros(size_in_bytes + align, dtype=np.uint8)
    # bytes to skip from the buffer start up to the next multiple of `align`
    shift = (-raw.ctypes.data) % align
    aligned_bytes = raw[shift:shift + size_in_bytes]
    return aligned_bytes.view(item_type)
# ########################
# #File management
[docs]
def get_new_file_name(base_path=None, base_name='tttr_data'):
    """Return the next free ``<base_name>_NNN`` file name inside today's dated folder.

    Files are organised as ``base_path/YYYY/YYYYMMDD``; the year and day folders
    are created when missing (base_path itself must exist).

    Parameters
    ----------
    base_path: str or Path, optional
        root saving folder. Defaults to the ('data_saving', 'h5file', 'save_path')
        configuration entry, read at call time.
    base_name: str
        prefix of the file names

    Returns
    -------
    tuple(str, Path): the new file name (no extension) and the dated folder
    """
    if base_path is None:
        base_path = config('data_saving', 'h5file', 'save_path')
    base_path = Path(base_path)
    today = datetime.datetime.now()
    year_dir = base_path.joinpath(today.strftime('%Y'))
    curr_dir = year_dir.joinpath(today.strftime('%Y%m%d'))
    # exist_ok avoids the check-then-create race of an is_dir() test
    year_dir.mkdir(exist_ok=True)
    curr_dir.mkdir(exist_ok=True)

    # Parse the numeric suffix of every existing `<base_name>_<digits>` file and
    # take the max. (Sorting names lexicographically breaks past 999: "_1000" < "_999".)
    prefix = f'{base_name}_'
    indexes = []
    for entry in curr_dir.iterdir():
        if entry.is_file() and entry.stem.startswith(prefix):
            suffix = entry.stem[len(prefix):]
            if suffix.isdecimal():
                indexes.append(int(suffix))
    index = max(indexes) + 1 if indexes else 0
    file = f'{base_name}_{index:03d}'
    return file, curr_dir
if __name__ == '__main__':
    # Developer scratch area: ad-hoc maintenance snippets run by hand, not part of the API.
    # Only the recursive_find_expr_in_files call below is active; the rest is commented out.
    #plugins = get_plugins() # pragma: no cover
    #extensions = get_extension()
    #models = get_models()
    #count = count_lines('C:\\Users\\weber\\Labo\\Programmes Python\\PyMoDAQ_Git\\pymodaq\src')
    # import license
    # mit = license.find('MIT')
    #
    # Search the tree for the literal text 'multiaxes' (quotes included). replace=False,
    # so files are only scanned, never rewritten (replace_str is then unused).
    # NOTE(review): the root folder is a hard-coded Windows path, presumably the original
    # author's checkout; adapt it before running.
    paths = recursive_find_expr_in_files(r'C:\Users\weber\Labo\ProgrammesPython\PyMoDAQ_Git',
                                         exp="'multiaxes'",
                                         paths=[],
                                         filters=['.git', '.idea', '__pycache__', 'build', 'egg', 'documentation',
                                                  '.tox',],
                                         replace=False,
                                         replace_str="pymodaq.utils")
    #get_version()
    pass
    # paths = recursive_find_files('C:\\Users\\weber\\Labo\\Programmes Python\\PyMoDAQ_Git',
    #            exp='VERSION', paths=[])
    # import version
    # for file in paths:
    #     with open(str(file), 'r') as f:
    #         v = version.Version(f.read())
    #         v.minor += 1
    #         v.patch = 0
    #     with open(str(file), 'w') as f:
    #         f.write(str(v))
    # for file in paths:
    #     with open(str(file), 'w') as f:
    #         f.write(mit.render(name='Sebastien Weber', email='sebastien.weber@cemes.fr'))