# SPDX-License-Identifier: AGPL-3.0-or-later
# lint: pylint
# pylint: disable=missing-module-docstring, missing-function-docstring, global-statement

import asyncio
import threading
import concurrent.futures
from timeit import default_timer

import httpx
import h2.exceptions

from .network import get_network, initialize
from .client import get_loop
from .raise_for_httperror import raise_for_httperror

# queue.SimpleQueue only exists since Python 3.7; on Python 3.6 fall back to a
# minimal backport providing the put() / blocking get() pair used in this module.
try:
    from queue import SimpleQueue
except ImportError:
    from queue import Empty
    from collections import deque

    class SimpleQueue:
        """Minimal backport of queue.SimpleQueue (unbounded FIFO with a blocking get)."""

        def __init__(self):
            # the semaphore counts how many items are waiting in the deque
            self._items = deque()
            self._available = threading.Semaphore(0)

        def put(self, item):
            """Append ``item`` and wake up one pending :py:meth:`get`."""
            # append before release so a woken get() never pops an empty deque
            self._items.append(item)
            self._available.release()

        def get(self):
            """Remove and return the oldest item, blocking until one is available."""
            got_item = self._available.acquire(blocking=True)  #pylint: disable=consider-using-with
            if not got_item:
                raise Empty
            return self._items.popleft()


#: Per-thread state used by this module: ``total_time``, ``timeout``,
#: ``start_time`` and ``network`` are stored here by the functions below.
THREADLOCAL = threading.local()

def reset_time_for_thread():
    """Set the calling thread's accumulated request time to 0.

    :py:func:`request` only adds to this counter once it exists in the thread.
    """
    THREADLOCAL.total_time = 0


def get_time_for_thread():
    """Return the calling thread's accumulated request time.

    Returns ``None`` when :py:func:`reset_time_for_thread` was never called in
    this thread.
    """
    return getattr(THREADLOCAL, 'total_time', None)


def set_timeout_for_thread(timeout, start_time=None):
    """Store the calling thread's default timeout and its reference start time.

    :py:func:`request` uses ``timeout`` when the caller does not pass one.
    When ``start_time`` is not ``None``, the time elapsed since it is
    subtracted from the effective timeout.
    """
    local_data = THREADLOCAL
    local_data.start_time = start_time
    local_data.timeout = timeout


def set_context_network_name(network_name):
    """Bind the network named ``network_name`` to the calling thread.

    The network is looked up with :py:func:`get_network` and is then returned by
    :py:func:`get_context_network` in this thread.
    """
    selected = get_network(network_name)
    THREADLOCAL.network = selected


def get_context_network():
    """Return the network bound to the calling thread.

    Falls back to :py:func:`get_network` (called without argument) when no
    (truthy) network was bound with :py:func:`set_context_network_name`.
    """
    thread_network = getattr(THREADLOCAL, 'network', None)
    return thread_network or get_network()


def request(method, url, **kwargs):
    """Send an HTTP request, like ``requests.request(...)``.

    The ``network.request(...)`` coroutine of the thread's network (see
    :py:func:`get_context_network`) is run on the loop returned by
    :py:func:`get_loop`, and the calling thread blocks until it completes.

    Keyword arguments are forwarded to ``network.request``, except
    ``raise_for_httperror`` (default ``True``), which is consumed here: when true,
    the response is passed to :py:func:`raise_for_httperror`.

    Timeout: ``kwargs['timeout']`` if given, otherwise the thread's timeout from
    :py:func:`set_timeout_for_thread` (also forwarded to the network), otherwise
    120 seconds. 0.2 s of overhead is added, and the time elapsed since the
    thread's ``start_time`` (when set) is subtracted.

    The time spent is added to the thread's total time (see
    :py:func:`reset_time_for_thread`), whether the request succeeds or fails.

    A ``bytes`` url is decoded, and ``response.ok`` is set for requests
    compatibility.

    :raises httpx.TimeoutException: when the remaining time budget runs out; the
        pending request is cancelled.
    """
    time_before_request = default_timer()

    # timeout (httpx): explicit argument first, then the thread's default
    if 'timeout' in kwargs:
        timeout = kwargs['timeout']
    else:
        timeout = getattr(THREADLOCAL, 'timeout', None)
        if timeout is not None:
            kwargs['timeout'] = timeout

    # 2 minutes timeout for the requests without timeout
    timeout = timeout or 120

    # adjust actual timeout: small overhead, minus the time already spent
    # since the thread's start_time
    timeout += 0.2  # overhead
    start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
    if start_time:
        timeout -= default_timer() - start_time

    # raise_for_httperror is not a network argument: consume it here
    check_for_httperror = kwargs.pop('raise_for_httperror', True)

    # requests compatibility
    if isinstance(url, bytes):
        url = url.decode()

    # network
    network = get_context_network()

    # do request
    future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
    try:
        response = future.result(timeout)
    except concurrent.futures.TimeoutError as e:
        # future.result() only stops waiting: cancel the coroutine too, otherwise
        # the request keeps running on the event loop.
        future.cancel()
        raise httpx.TimeoutException('Timeout', request=None) from e
    finally:
        # update total_time, also for failed and timed out requests.
        # See get_time_for_thread() and reset_time_for_thread()
        if hasattr(THREADLOCAL, 'total_time'):
            THREADLOCAL.total_time += default_timer() - time_before_request

    # requests compatibility
    # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses
    response.ok = not response.is_error

    # raise an exception
    if check_for_httperror:
        raise_for_httperror(response)

    return response


def get(url, **kwargs):
    """Send a GET request; redirects are followed unless ``allow_redirects`` is given."""
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return request('get', url, **kwargs)


def options(url, **kwargs):
    """Send an OPTIONS request; redirects are followed unless ``allow_redirects`` is given."""
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return request('options', url, **kwargs)


def head(url, **kwargs):
    """Send a HEAD request; redirects are NOT followed unless ``allow_redirects`` is given."""
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = False
    return request('head', url, **kwargs)


def post(url, data=None, **kwargs):
    """Send a POST request with an optional ``data`` body."""
    kwargs['data'] = data
    return request('post', url, **kwargs)


def put(url, data=None, **kwargs):
    """Send a PUT request with an optional ``data`` body."""
    kwargs['data'] = data
    return request('put', url, **kwargs)


def patch(url, data=None, **kwargs):
    """Send a PATCH request with an optional ``data`` body."""
    kwargs['data'] = data
    return request('patch', url, **kwargs)


def delete(url, **kwargs):
    """Send a DELETE request; keyword arguments are passed on to :py:func:`request`."""
    http_method = 'delete'
    return request(http_method, url, **kwargs)

async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
    """Feed a streamed HTTP response into ``queue`` (producer side of :py:func:`stream`).

    Puts, in order: the response object, then every non-empty body chunk (read
    with a chunk size of 65536), then ``None`` as end marker. An
    ``httpx.HTTPError``, ``OSError`` or ``h2.exceptions.ProtocolError`` is put
    into the queue instead of being raised. ``None`` is put last in every case.
    """
    try:
        async with network.stream(method, url, **kwargs) as response:
            queue.put(response)
            async for chunk in response.aiter_bytes(65536):
                if chunk:  # skip empty chunks
                    queue.put(chunk)
    except (httpx.HTTPError, OSError, h2.exceptions.ProtocolError) as exc:
        # hand the error over to the consuming thread
        queue.put(exc)
    finally:
        # end-of-stream marker, also after an error
        queue.put(None)


def stream(method, url, **kwargs):
    """Replace httpx.stream.

    Returns a generator. Its first item is the response object, and the following
    items are the non-empty chunks of the body::

        stream = stream(...)
        response = next(stream)
        for chunk in stream:
            ...

    The request runs on the loop returned by :py:func:`get_loop`, through the
    thread's network (see :py:func:`get_context_network`), like :py:func:`request`.
    An ``httpx.HTTPError``, ``OSError`` or ``h2.exceptions.ProtocolError`` raised
    while sending the request or reading the body is re-raised by the generator.

    If the generator is closed before the body is exhausted (explicit ``close()``
    or garbage collection of an unfinished generator), the producer coroutine is
    cancelled, so the rest of the body is not read into the queue.
    """
    queue = SimpleQueue()
    future = asyncio.run_coroutine_threadsafe(
        stream_chunk_to_queue(get_context_network(), queue, method, url, **kwargs),
        get_loop()
    )
    try:
        chunk_or_exception = queue.get()
        # None is the end marker put by stream_chunk_to_queue
        while chunk_or_exception is not None:
            if isinstance(chunk_or_exception, Exception):
                raise chunk_or_exception
            yield chunk_or_exception
            chunk_or_exception = queue.get()
        return future.result()
    finally:
        # no-op once the coroutine has finished; otherwise (consumer stopped
        # early or an error was raised) stop it instead of letting it run on
        if not future.done():
            future.cancel()