[mod] searx.utils: more typing

This commit is contained in:
Alexandre Flament 2022-01-30 22:14:12 +01:00
parent 0eacc46ee3
commit 2d5929cc59

View file

@ -6,6 +6,8 @@
""" """
import re import re
import importlib import importlib
import importlib.util
import types
from typing import Optional, Union, Any, Set, List, Dict, MutableMapping, Tuple, Callable from typing import Optional, Union, Any, Set, List, Dict, MutableMapping, Tuple, Callable
from numbers import Number from numbers import Number
@ -45,8 +47,8 @@ _STORAGE_UNIT_VALUE: Dict[str, int] = {
'KiB': 1000, 'KiB': 1000,
} }
_XPATH_CACHE = {} _XPATH_CACHE: Dict[str, XPath] = {}
_LANG_TO_LC_CACHE = {} _LANG_TO_LC_CACHE: Dict[str, Dict[str, str]] = {}
class _NotSetClass: # pylint: disable=too-few-public-methods class _NotSetClass: # pylint: disable=too-few-public-methods
@ -150,7 +152,7 @@ def html_to_text(html_str: str) -> str:
return s.get_text() return s.get_text()
def extract_text(xpath_results, allow_none: bool = False): def extract_text(xpath_results, allow_none: bool = False) -> Optional[str]:
"""Extract text from a lxml result """Extract text from a lxml result
* if xpath_results is list, extract the text from each result and concat the list * if xpath_results is list, extract the text from each result and concat the list
@ -264,7 +266,9 @@ def extract_url(xpath_results, base_url) -> str:
raise ValueError('Empty url resultset') raise ValueError('Empty url resultset')
url = extract_text(xpath_results) url = extract_text(xpath_results)
if url:
return normalize_url(url, base_url) return normalize_url(url, base_url)
raise ValueError('URL not found')
def dict_subset(dictionnary: MutableMapping, properties: Set[str]) -> Dict: def dict_subset(dictionnary: MutableMapping, properties: Set[str]) -> Dict:
@ -366,7 +370,7 @@ def _get_lang_to_lc_dict(lang_list: List[str]) -> Dict[str, str]:
# babel's get_global contains all sorts of miscellaneous locale and territory related data # babel's get_global contains all sorts of miscellaneous locale and territory related data
# see get_global in: https://github.com/python-babel/babel/blob/master/babel/core.py # see get_global in: https://github.com/python-babel/babel/blob/master/babel/core.py
def _get_from_babel(lang_code, key: str): def _get_from_babel(lang_code: str, key: str):
match = get_global(key).get(lang_code.replace('-', '_')) match = get_global(key).get(lang_code.replace('-', '_'))
# for some keys, such as territory_aliases, match may be a list # for some keys, such as territory_aliases, match may be a list
if isinstance(match, str): if isinstance(match, str):
@ -374,7 +378,7 @@ def _get_from_babel(lang_code, key: str):
return match return match
def _match_language(lang_code, lang_list=[], custom_aliases={}) -> Optional[str]: # pylint: disable=W0102 def _match_language(lang_code: str, lang_list=[], custom_aliases={}) -> Optional[str]: # pylint: disable=W0102
"""auxiliary function to match lang_code in lang_list""" """auxiliary function to match lang_code in lang_list"""
# replace language code with a custom alias if necessary # replace language code with a custom alias if necessary
if lang_code in custom_aliases: if lang_code in custom_aliases:
@ -396,10 +400,12 @@ def _match_language(lang_code, lang_list=[], custom_aliases={}) -> Optional[str]
return new_code return new_code
# try to get the any supported country for this language # try to get the any supported country for this language
return _get_lang_to_lc_dict(lang_list).get(lang_code, None) return _get_lang_to_lc_dict(lang_list).get(lang_code)
def match_language(locale_code, lang_list=[], custom_aliases={}, fallback='en-US') -> str: # pylint: disable=W0102 def match_language( # pylint: disable=W0102
locale_code, lang_list=[], custom_aliases={}, fallback: Optional[str] = 'en-US'
) -> Optional[str]:
"""get the language code from lang_list that best matches locale_code""" """get the language code from lang_list that best matches locale_code"""
# try to get language from given locale_code # try to get language from given locale_code
language = _match_language(locale_code, lang_list, custom_aliases) language = _match_language(locale_code, lang_list, custom_aliases)
@ -437,12 +443,16 @@ def match_language(locale_code, lang_list=[], custom_aliases={}, fallback='en-US
return language or fallback return language or fallback
def load_module(filename: str, module_dir: str): def load_module(filename: str, module_dir: str) -> types.ModuleType:
modname = splitext(filename)[0] modname = splitext(filename)[0]
filepath = join(module_dir, filename) modpath = join(module_dir, filename)
# and https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly # and https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(modname, filepath) spec = importlib.util.spec_from_file_location(modname, modpath)
if not spec:
raise ValueError(f"Error loading '{modpath}' module")
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
if not spec.loader:
raise ValueError(f"Error loading '{modpath}' module")
spec.loader.exec_module(module) spec.loader.exec_module(module)
return module return module
@ -477,7 +487,7 @@ def ecma_unescape(string: str) -> str:
return string return string
def get_string_replaces_function(replaces: Dict[str, str]) -> Callable: def get_string_replaces_function(replaces: Dict[str, str]) -> Callable[[str], str]:
rep = {re.escape(k): v for k, v in replaces.items()} rep = {re.escape(k): v for k, v in replaces.items()}
pattern = re.compile("|".join(rep.keys())) pattern = re.compile("|".join(rep.keys()))