Merge pull request #1047 from return42/redis-lib

Add a redis library to generalize DB functions we need in SearXNG.
This commit is contained in:
Alexandre Flament 2022-06-06 10:59:11 +02:00 committed by GitHub
commit ea0cddba0b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 268 additions and 45 deletions

View file

@ -139,6 +139,7 @@ intersphinx_mapping = {
"jinja": ("https://jinja.palletsprojects.com/", None), "jinja": ("https://jinja.palletsprojects.com/", None),
"linuxdoc" : ("https://return42.github.io/linuxdoc/", None), "linuxdoc" : ("https://return42.github.io/linuxdoc/", None),
"sphinx" : ("https://www.sphinx-doc.org/en/master/", None), "sphinx" : ("https://www.sphinx-doc.org/en/master/", None),
"redis": ('https://redis.readthedocs.io/en/stable/', None),
} }
issues_github_path = "searxng/searxng" issues_github_path = "searxng/searxng"

View file

@ -0,0 +1,8 @@
.. _searx.redis:
=============
Redis Library
=============
.. automodule:: searx.redislib
:members:

View file

@ -13,11 +13,11 @@ Enable the plugin in ``settings.yml``:
- ``redis.url: ...`` check the value, see :ref:`settings redis` - ``redis.url: ...`` check the value, see :ref:`settings redis`
""" """
import hmac
import re import re
from flask import request from flask import request
from searx.shared import redisdb from searx.shared import redisdb
from searx.redislib import incr_sliding_window
name = "Request limiter" name = "Request limiter"
description = "Limit the number of request" description = "Limit the number of request"
@ -36,8 +36,9 @@ re_bot = re.compile(
) )
def is_accepted_request(inc_get_counter) -> bool: def is_accepted_request() -> bool:
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
redis_client = redisdb.client()
user_agent = request.headers.get('User-Agent', '') user_agent = request.headers.get('User-Agent', '')
x_forwarded_for = request.headers.get('X-Forwarded-For', '') x_forwarded_for = request.headers.get('X-Forwarded-For', '')
@ -47,71 +48,46 @@ def is_accepted_request(inc_get_counter) -> bool:
return True return True
if request.path == '/search': if request.path == '/search':
c_burst = inc_get_counter(interval=20, keys=[b'IP limit, burst', x_forwarded_for]) c_burst = incr_sliding_window(redis_client, 'IP limit, burst' + x_forwarded_for, 20)
c_10min = inc_get_counter(interval=600, keys=[b'IP limit, 10 minutes', x_forwarded_for]) c_10min = incr_sliding_window(redis_client, 'IP limit, 10 minutes' + x_forwarded_for, 600)
if c_burst > 15 or c_10min > 150: if c_burst > 15 or c_10min > 150:
logger.debug("to many request") # pylint: disable=undefined-variable
return False return False
if re_bot.match(user_agent): if re_bot.match(user_agent):
logger.debug("detected bot") # pylint: disable=undefined-variable
return False return False
if len(request.headers.get('Accept-Language', '').strip()) == '': if len(request.headers.get('Accept-Language', '').strip()) == '':
logger.debug("missing Accept-Language") # pylint: disable=undefined-variable
return False return False
if request.headers.get('Connection') == 'close': if request.headers.get('Connection') == 'close':
logger.debug("got Connection=close") # pylint: disable=undefined-variable
return False return False
accept_encoding_list = [l.strip() for l in request.headers.get('Accept-Encoding', '').split(',')] accept_encoding_list = [l.strip() for l in request.headers.get('Accept-Encoding', '').split(',')]
if 'gzip' not in accept_encoding_list or 'deflate' not in accept_encoding_list: if 'gzip' not in accept_encoding_list or 'deflate' not in accept_encoding_list:
logger.debug("suspicious Accept-Encoding") # pylint: disable=undefined-variable
return False return False
if 'text/html' not in request.accept_mimetypes: if 'text/html' not in request.accept_mimetypes:
logger.debug("Accept-Encoding misses text/html") # pylint: disable=undefined-variable
return False return False
if request.args.get('format', 'html') != 'html': if request.args.get('format', 'html') != 'html':
c = inc_get_counter(interval=3600, keys=[b'API limit', x_forwarded_for]) c = incr_sliding_window(redis_client, 'API limit' + x_forwarded_for, 3600)
if c > 4: if c > 4:
logger.debug("API limit exceeded") # pylint: disable=undefined-variable
return False return False
return True return True
def create_inc_get_counter(redis_client, secret_key_bytes):
lua_script = """
local slidingWindow = KEYS[1]
local key = KEYS[2]
local now = tonumber(redis.call('TIME')[1])
local id = redis.call('INCR', 'counter')
if (id > 2^46)
then
redis.call('SET', 'count', 0)
end
redis.call('ZREMRANGEBYSCORE', key, 0, now - slidingWindow)
redis.call('ZADD', key, now, id)
local result = redis.call('ZCOUNT', key, 0, now+1)
redis.call('EXPIRE', key, slidingWindow)
return result
"""
script_sha = redis_client.script_load(lua_script)
def inc_get_counter(interval, keys):
m = hmac.new(secret_key_bytes, digestmod='sha256')
for k in keys:
m.update(bytes(str(k), encoding='utf-8') or b'')
m.update(b"\0")
key = m.digest()
return redis_client.evalsha(script_sha, 2, interval, key)
return inc_get_counter
def create_pre_request(get_aggregation_count):
def pre_request(): def pre_request():
if not is_accepted_request(get_aggregation_count): if not is_accepted_request():
return '', 429 return '', 429
return None return None
return pre_request
def init(app, settings): def init(app, settings):
if not settings['server']['limiter']: if not settings['server']['limiter']:
@ -122,8 +98,5 @@ def init(app, settings):
logger.error("init limiter DB failed!!!") # pylint: disable=undefined-variable logger.error("init limiter DB failed!!!") # pylint: disable=undefined-variable
return False return False
redis_client = redisdb.client() app.before_request(pre_request)
secret_key_bytes = bytes(settings['server']['secret_key'], encoding='utf-8')
inc_get_counter = create_inc_get_counter(redis_client, secret_key_bytes)
app.before_request(create_pre_request(inc_get_counter))
return True return True

241
searx/redislib.py Normal file
View file

@ -0,0 +1,241 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# lint: pylint
"""A collection of convenient functions and redis/lua scripts.
This code was partially inspired by the `Bullet-Proofing Lua Scripts in RedisPy`_
article.
.. _Bullet-Proofing Lua Scripts in RedisPy:
https://redis.com/blog/bullet-proofing-lua-scripts-in-redispy/
"""
import hmac
from searx import get_setting
LUA_SCRIPT_STORAGE = {}
"""Module-level cache of redis ``Script`` objects.  The outer key is the
``id()`` of the client, the inner key is the Lua source text.  The cache is
filled and read by :py:obj:`lua_script_storage`."""


def lua_script_storage(client, script):
    """Return the redis :py:obj:`Script
    <redis.commands.core.CoreCommands.register_script>` object of ``script``.

    For performance reasons ``client.register_script(script)`` is called only
    once per client and Lua source; the resulting ``Script`` object is kept in
    :py:obj:`LUA_SCRIPT_STORAGE` and handed out again on every later call.

    :param client: redis client, must offer ``register_script(script)``
    :param script: source code of the Lua script
    :type script: str
    :return: the (cached) object returned by ``client.register_script``
    """
    # The redis connection can be closed, therefore the id() of the redis
    # connector object is used as key of the per-client storage.
    scripts_of_client = LUA_SCRIPT_STORAGE.setdefault(id(client), {})

    registered = scripts_of_client.get(script)
    if registered is None:
        registered = client.register_script(script)
        scripts_of_client[script] = registered
    return registered
PURGE_BY_PREFIX = """
local prefix = tostring(ARGV[1])
for i, name in ipairs(redis.call('KEYS', prefix .. '*')) do
redis.call('EXPIRE', name, 0)
end
"""


def purge_by_prefix(client, prefix: str = "SearXNG_"):
    """Purge all keys starting with ``prefix`` from the database.

    The Lua script :py:obj:`PURGE_BY_PREFIX` looks up every key matching
    ``<prefix>*`` and sets its expire time to zero.  With the default prefix
    all keys set by SearXNG are dropped (the SearXNG schema is removed from the
    database entirely).

    The script uses EXPIRE_ instead of DEL_: if there are a lot of keys to
    delete and/or their values are big, ``DEL`` could take more time and block
    the command loop, while ``EXPIRE`` returns immediately.

    :param prefix: prefix of the keys to delete (default: ``SearXNG_``)
    :type prefix: str

    .. _EXPIRE: https://redis.io/commands/expire/
    .. _DEL: https://redis.io/commands/del/

    """
    purge = lua_script_storage(client, PURGE_BY_PREFIX)
    # the prefix is handed over as ARGV[1] of the Lua script
    purge(args=[prefix])
def secret_hash(name: str):
    """Create a hash of ``name``.

    Combines the argument ``name`` with the ``secret_key`` from :ref:`settings
    server`.  This function can be used to get a more anonymised name of a
    Redis KEY.

    :param name: the name to create a secret hash for
    :type name: str
    :return: hexadecimal SHA-256 HMAC digest
    :type return: str
    """
    # NOTE: ``name`` is the HMAC key and ``server.secret_key`` the message.
    hmac_key = bytes(name, encoding='utf-8')
    secret = bytes(get_setting('server.secret_key'), encoding='utf-8')
    return hmac.new(hmac_key, secret, digestmod='sha256').hexdigest()
INCR_COUNTER = """
local limit = tonumber(ARGV[1])
local expire = tonumber(ARGV[2])
local c_name = KEYS[1]
local c = redis.call('GET', c_name)
if not c then
c = redis.call('INCR', c_name)
if expire > 0 then
redis.call('EXPIRE', c_name, expire)
end
else
c = tonumber(c)
if limit == 0 or c < limit then
c = redis.call('INCR', c_name)
end
end
return c
"""


def incr_counter(client, name: str, limit: int = 0, expire: int = 0):
    """Increment a counter and return the new value.

    If the counter with the redis key ``SearXNG_counter_<name>`` does not
    exist, it is created and its initial value 1 is returned.  The replacement
    ``<name>`` is a *secret hash* of the value from argument ``name`` (see
    :py:func:`secret_hash`).

    The implementation of the redis counter is the Lua script from string
    :py:obj:`INCR_COUNTER`.

    :param name: name of the counter
    :type name: str

    :param limit: value at which the counter stops to increment (default ``0``
        means no limit)
    :type limit: int / see INCR_

    :param expire: life-time of the counter in seconds, set when the counter
        is created (default ``0`` means no expire time is set)
    :type expire: int / see EXPIRE_

    :return: value of the counter after the call
    :type return: int

    .. _EXPIRE: https://redis.io/commands/expire/
    .. _INCR: https://redis.io/commands/incr/

    A simple demo of a counter with expire time and limit::

      >>> for i in range(6):
      ...   i, incr_counter(client, "foo", 3, 5) # max 3, duration 5 sec
      ...   time.sleep(1) # from the third call on max has been reached
      ...
      (0, 1)
      (1, 2)
      (2, 3)
      (3, 3)
      (4, 3)
      (5, 1)

    """
    run_script = lua_script_storage(client, INCR_COUNTER)
    redis_key = "SearXNG_counter_" + secret_hash(name)
    # ARGV[1] is the limit, ARGV[2] the expire time, KEYS[1] the counter key
    return run_script(keys=[redis_key], args=[limit, expire])
def drop_counter(client, name):
    """Drop the counter with the redis key ``SearXNG_counter_<name>``.

    The replacement ``<name>`` is a *secret hash* of the value from argument
    ``name`` (see :py:func:`incr_counter` and :py:func:`incr_sliding_window`).

    :param client: redis client, its ``delete`` method is called with the key
    :param name: name of the counter
    """
    client.delete("SearXNG_counter_" + secret_hash(name))
INCR_SLIDING_WINDOW = """
local expire = tonumber(ARGV[1])
local name = KEYS[1]
local current_time = redis.call('TIME')
redis.call('ZREMRANGEBYSCORE', name, 0, current_time[1] - expire)
redis.call('ZADD', name, current_time[1], current_time[1] .. current_time[2])
local result = redis.call('ZCOUNT', name, 0, current_time[1] + 1)
redis.call('EXPIRE', name, expire)
return result
"""


def incr_sliding_window(client, name: str, duration: int):
    """Increment a sliding-window counter and return the new value.

    If the counter with the redis key ``SearXNG_counter_<name>`` does not
    exist, it is created and its initial value 1 is returned.  The replacement
    ``<name>`` is a *secret hash* of the value from argument ``name`` (see
    :py:func:`secret_hash`).

    :param name: name of the counter
    :type name: str

    :param duration: life-time of the sliding window in seconds
    :type duration: int

    :return: value of the incremented counter
    :type return: int

    The implementation of the redis counter is the Lua script from string
    :py:obj:`INCR_SLIDING_WINDOW`.  The Lua script uses `sorted sets in Redis`_
    to implement a sliding window for the redis key ``SearXNG_counter_<name>``
    (ZADD_).  The current TIME_ is used to score the items in the sorted set
    and the time window is moved by removing items with a score lower than the
    current time minus the *duration* time (ZREMRANGEBYSCORE_).

    The EXPIRE_ time (the duration of the sliding window) is refreshed on each
    call (incrementation) and if there is no call in this duration, the sorted
    set expires from the redis DB.

    The return value is the amount of items in the sorted set (ZCOUNT_), which
    means the number of calls in the sliding window.

    .. _Sorted sets in Redis:
       https://redis.com/ebook/part-1-getting-started/chapter-1-getting-to-know-redis/1-2-what-redis-data-structures-look-like/1-2-5-sorted-sets-in-redis/
    .. _TIME: https://redis.io/commands/time/
    .. _ZADD: https://redis.io/commands/zadd/
    .. _EXPIRE: https://redis.io/commands/expire/
    .. _ZREMRANGEBYSCORE: https://redis.io/commands/zremrangebyscore/
    .. _ZCOUNT: https://redis.io/commands/zcount/

    A simple demo of the sliding window::

      >>> for i in range(5):
      ...   incr_sliding_window(client, "foo", 3) # duration 3 sec
      ...   time.sleep(1) # from the third call (second) on the window is moved
      ...
      1
      2
      3
      3
      3
      >>> time.sleep(3)  # wait until expire
      >>> incr_sliding_window(client, "foo", 3)
      1

    """
    run_script = lua_script_storage(client, INCR_SLIDING_WINDOW)
    redis_key = "SearXNG_counter_" + secret_hash(name)
    # ARGV[1] is the window duration, KEYS[1] the key of the sorted set
    return run_script(keys=[redis_key], args=[duration])