304 lines
10 KiB
Python
304 lines
10 KiB
Python
import asyncio
|
|
import collections
|
|
import threading
|
|
from http import HTTPStatus
|
|
from io import BytesIO
|
|
from typing import Any, Coroutine, Deque, Iterable, Optional, TypeVar
|
|
from typing import cast as typing_cast
|
|
|
|
from .asgi_typing import HTTPScope, ASGIApp, ReceiveEvent, SendEvent
|
|
from .wsgi_typing import Environ, StartResponse, IterableChunks
|
|
|
|
# Generic result type: lets ``ASGIResponder.execute_in_loop`` be annotated as
# returning whatever the coroutine it is given produces.
T = TypeVar("T")
|
|
|
|
|
|
class defaultdict(dict):
    """A ``dict`` whose fallback for an absent key is computed from that key.

    In contrast to ``collections.defaultdict``, the factory receives the
    missing key as its single argument, and the value it produces is only
    returned -- it is not inserted into the mapping.
    """

    def __init__(self, default_factory, *args, **kwargs) -> None:
        """Remember ``default_factory``; everything else is passed to ``dict``."""
        self.default_factory = default_factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        """Hook used by ``dict.__getitem__`` when ``key`` is not present."""
        make_default = self.default_factory
        return make_default(key)
|
|
|
|
|
|
# Lookup table from an integer status code to the status line given to the
# WSGI ``start_response`` callable ("<code> <phrase>"). Codes that
# ``http.HTTPStatus`` does not define resolve to "<code> Unknown Status Code".
StatusStringMapping = defaultdict(
    lambda code: f"{code} Unknown Status Code",
    (
        (member.value, f"{member.value} {member.phrase}")
        for member in HTTPStatus
    ),
)
|
|
|
|
|
|
class AsyncEvent:
    """Deliver a message to the coroutines currently waiting on ``loop``.

    A message is only seen by waiters that are registered at the moment it is
    delivered; nothing is buffered for later callers of :meth:`wait`.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self.loop = loop
        # Futures of the coroutines that are currently blocked in ``wait``.
        self.__waiters: Deque[asyncio.Future] = collections.deque()
        # Once switched on, ``wait`` returns ``None`` without blocking.
        self.__nowait = False

    def _set(self, message: Any) -> None:
        """Resolve every still-pending waiter with ``message``."""
        for pending in self.__waiters:
            if not pending.done():
                pending.set_result(message)

    def set(self, message: Any) -> None:
        """Schedule delivery of ``message`` on the loop via ``call_soon_threadsafe``."""
        self.loop.call_soon_threadsafe(self._set, message)

    async def wait(self) -> Any:
        """Wait for the next message and return it.

        Returns ``None`` immediately when :meth:`set_nowait` has been called.
        """
        if self.__nowait:
            return None

        waiter = self.loop.create_future()
        self.__waiters.append(waiter)
        try:
            return await waiter
        finally:
            # Deregister on success, error and cancellation alike.
            self.__waiters.remove(waiter)

    def set_nowait(self) -> None:
        """Make every later :meth:`wait` call return ``None`` without waiting."""
        self.__nowait = True
|
|
|
|
|
|
class SyncEvent:
    """Single-slot mailbox on which a thread can block until a message arrives.

    A second :meth:`set` before the pending message is consumed overwrites it.
    """

    def __init__(self) -> None:
        self.__write_event = threading.Event()
        self.__message: Any = None

    def set(self, message: Any) -> None:
        """Store ``message`` and wake up the thread blocked in :meth:`wait`."""
        self.__message = message
        self.__write_event.set()

    def wait(self) -> Any:
        """Block until a message has been set, then take it out of the slot."""
        flag = self.__write_event
        flag.wait()
        flag.clear()
        message = self.__message
        self.__message = None
        return message
|
|
|
|
|
|
def build_scope(environ: Environ) -> HTTPScope:
    """Build the ASGI HTTP connection scope for a WSGI ``environ``.

    Request headers come from the ``HTTP_*`` variables plus ``CONTENT_TYPE``
    and ``CONTENT_LENGTH``. ``HTTP_CONTENT_TYPE`` / ``HTTP_CONTENT_LENGTH`` are
    skipped (presumably so these two headers cannot appear twice). Header
    names are lower-cased with ``_`` turned into ``-``; names and values are
    encoded as latin-1.

    ``SCRIPT_NAME`` and ``PATH_INFO`` are re-decoded from latin-1 to UTF-8, and
    ``path`` is the concatenation of both.

    Args:
        environ: WSGI environment of the current request. ``REQUEST_METHOD``,
            ``SERVER_NAME`` and ``SERVER_PORT`` are required.

    Returns:
        The scope dict. ``client`` is only present when both ``REMOTE_ADDR``
        and ``REMOTE_PORT`` are non-empty. The original environ is kept under
        the non-standard ``wsgi_environ`` key.

    Raises:
        KeyError: if a required variable is missing.
        ValueError: if ``SERVER_PORT`` or ``REMOTE_PORT`` is not an integer.
    """
    headers = [
        (
            (key[5:] if key.startswith("HTTP_") else key)
            .lower()
            .replace("_", "-")
            .encode("latin-1"),
            value.encode("latin-1"),  # type: ignore
        )
        for key, value in environ.items()
        if (
            key.startswith("HTTP_")
            and key not in ("HTTP_CONTENT_TYPE", "HTTP_CONTENT_LENGTH")
        )
        or key in ("CONTENT_TYPE", "CONTENT_LENGTH")
    ]

    root_path = environ.get("SCRIPT_NAME", "").encode("latin1").decode("utf8")
    path = root_path + environ.get("PATH_INFO", "").encode("latin1").decode("utf8")

    scope: HTTPScope = {
        "wsgi_environ": environ,  # type: ignore a2wsgi
        "type": "http",
        "asgi": {"version": "3.0", "spec_version": "2.5"},
        "http_version": environ.get("SERVER_PROTOCOL", "http/1.0").split("/")[1],
        "method": environ["REQUEST_METHOD"],
        "scheme": environ.get("wsgi.url_scheme", "http"),
        "path": path,
        # WSGI strings are bytes decoded as latin-1, so latin-1 is the lossless
        # way back to bytes. Encoding as ASCII would raise UnicodeEncodeError
        # for a query string that carries raw non-ASCII bytes.
        "query_string": environ.get("QUERY_STRING", "").encode("latin-1"),
        "root_path": root_path,
        "server": (environ["SERVER_NAME"], int(environ["SERVER_PORT"])),
        "headers": headers,
        "extensions": {},
    }
    if environ.get("REMOTE_ADDR") and environ.get("REMOTE_PORT"):
        client = (environ.get("REMOTE_ADDR", ""), int(environ.get("REMOTE_PORT", "0")))
        scope["client"] = client

    return scope
|
|
|
|
|
|
class ASGIMiddleware:
    """
    Expose an ASGI application as a WSGI application.

    app: the ASGI application to wrap.
    wait_time: once the HTTP response has ended, the longest time to keep
        waiting for the ASGI app to finish running.
    loop: event loop the app runs on; when omitted, a fresh loop is created
        and run forever in a daemon thread.
    """

    def __init__(
        self,
        app: ASGIApp,
        wait_time: Optional[float] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        self.app = app
        if loop is None:
            # No loop supplied: own one, kept spinning in a background thread.
            loop = asyncio.new_event_loop()
            threading.Thread(target=loop.run_forever, daemon=True).start()
        self.loop = loop
        self.wait_time = wait_time

    def __call__(
        self, environ: Environ, start_response: StartResponse
    ) -> Iterable[bytes]:
        """Handle one WSGI request with a fresh ``ASGIResponder``."""
        responder = ASGIResponder(self.app, self.loop, self.wait_time)
        return responder(environ, start_response)
|
|
|
|
|
|
class ASGIResponder:
    """Run one ASGI application call and expose it as a WSGI response iterable.

    The ASGI app runs as a task on ``loop``, while ``__call__`` is a generator
    consumed by the WSGI caller. ``__call__`` blocks on results from the loop,
    so the loop is expected to be running in a different thread.

    The two sides talk through three channels:

    * ``sync_event`` (ASGI -> WSGI): "receive" requests, response messages and
      the internal "a2wsgi.error" message.
    * ``receive_event`` (WSGI -> ASGI): request body chunks / disconnect.
    * ``send_event`` (WSGI -> ASGI): acknowledgement that a sent message has
      been handled.

    ``sync_event_set_lock`` is acquired by the ASGI side before it posts to
    ``sync_event`` and released by the WSGI side after it has picked the
    message up, so only one message occupies the slot at a time.
    """

    def __init__(
        self,
        app: ASGIApp,
        loop: asyncio.AbstractEventLoop,
        wait_time: Optional[float] = None,
    ) -> None:
        self.app = app
        self.loop = loop
        self.wait_time = wait_time

        self.sync_event = SyncEvent()
        self.sync_event_set_lock: asyncio.Lock

        self.receive_event = AsyncEvent(loop)
        self.send_event = AsyncEvent(loop)

        def _init_async_lock():
            self.sync_event_set_lock = asyncio.Lock()

        # The lock is created by a callback running on the loop (presumably so
        # that it is bound to that loop rather than to the calling thread).
        loop.call_soon_threadsafe(_init_async_lock)

        # Set once the ASGI task has finished, for whatever reason.
        self.asgi_done = threading.Event()
        # Set by ``__call__`` when the response is complete.
        self.wsgi_should_stop: bool = False

    async def asgi_receive(self) -> ReceiveEvent:
        """ASGI ``receive`` callable: ask the WSGI side for the next event."""
        await self.sync_event_set_lock.acquire()
        self.sync_event.set({"type": "receive"})
        return await self.receive_event.wait()

    async def asgi_send(self, message: SendEvent) -> None:
        """ASGI ``send`` callable: pass ``message`` on and wait for the ack."""
        await self.sync_event_set_lock.acquire()
        self.sync_event.set(message)
        await self.send_event.wait()

    def asgi_done_callback(self, future: asyncio.Future) -> None:
        """Done-callback of the ASGI task.

        If the task raised, the exception is forwarded to the WSGI side as an
        "a2wsgi.error" message (after taking the slot lock). Cancellation is
        ignored. ``asgi_done`` is set in every case.
        """
        try:
            exception = future.exception()
        except asyncio.CancelledError:
            pass
        else:
            if exception is not None:
                task = asyncio.create_task(self.sync_event_set_lock.acquire())
                task.add_done_callback(
                    lambda _: self.sync_event.set(
                        {
                            "type": "a2wsgi.error",
                            "exception": (
                                type(exception),
                                exception,
                                exception.__traceback__,
                            ),
                        }
                    )
                )
        finally:
            self.asgi_done.set()

    async def start_asgi_app(self, environ: Environ) -> asyncio.Task:
        """Create the task that runs the ASGI app and return it.

        This is a coroutine so that ``execute_in_loop`` can run it on the loop
        itself.
        """
        run_asgi: asyncio.Task = self.loop.create_task(
            typing_cast(
                Coroutine[None, None, None],
                self.app(build_scope(environ), self.asgi_receive, self.asgi_send),
            )
        )
        run_asgi.add_done_callback(self.asgi_done_callback)
        return run_asgi

    def execute_in_loop(self, coro: Coroutine[None, None, T]) -> T:
        """Run ``coro`` on the loop and block the calling thread for its result."""
        return asyncio.run_coroutine_threadsafe(coro, self.loop).result()

    def __call__(
        self, environ: Environ, start_response: StartResponse
    ) -> IterableChunks:
        """Serve one request: pump messages between the ASGI app and WSGI.

        Yields the response body chunks, followed by a final ``b""`` once the
        ASGI task has finished or ``wait_time`` has elapsed.

        Raises:
            RuntimeError: if the app sends an unknown message type.
            Exception: the ASGI app's own exception is re-raised when it fails
                after body data has already been yielded. Before that, a
                plain-text 500 response is produced instead.
        """
        read_count: int = 0
        body = environ["wsgi.input"] or BytesIO()
        content_length = int(environ.get("CONTENT_LENGTH", None) or 0)
        receive_eof = False
        body_sent = False

        asgi_task = self.execute_in_loop(self.start_asgi_app(environ))
        # activate loop
        self.loop.call_soon_threadsafe(lambda: None)

        while True:
            message = self.sync_event.wait()
            # The slot is free again: let the ASGI side post its next message.
            self.loop.call_soon_threadsafe(self.sync_event_set_lock.release)
            message_type = message["type"]

            if message_type == "http.response.start":
                start_response(
                    StatusStringMapping[message["status"]],
                    [
                        (
                            name.strip().decode("latin1"),
                            value.strip().decode("latin1"),
                        )
                        for name, value in message["headers"]
                    ],
                    None,
                )
                self.send_event.set(None)
            elif message_type == "http.response.body":
                yield message.get("body", b"")
                body_sent = True
                self.wsgi_should_stop = not message.get("more_body", False)
                self.send_event.set(None)
            elif message_type == "http.response.disconnect":
                self.wsgi_should_stop = True
                self.send_event.set(None)
            # ASGI application error
            elif message_type == "a2wsgi.error":
                if body_sent:
                    # Too late for an error page: surface the failure as is.
                    raise message["exception"][1].with_traceback(
                        message["exception"][2]
                    )
                start_response(
                    "500 Internal Server Error",
                    [
                        ("Content-Type", "text/plain; charset=utf-8"),
                        ("Content-Length", "28"),
                    ],
                    message["exception"],
                )
                yield b"Server got itself in trouble"
                self.wsgi_should_stop = True
            elif message_type == "receive":
                if receive_eof:
                    # The final chunk was already delivered: leave the app's
                    # `await receive()` waiting.
                    pass
                else:
                    read_size = min(65536, content_length - read_count)
                    # read_size == 0: no (more) body, so don't read anymore and
                    # just deliver one empty final chunk.
                    data: bytes = body.read(read_size) if read_size != 0 else b""
                    read_count += len(data)
                    # An empty read means the input ended before CONTENT_LENGTH
                    # bytes arrived. Report the body as finished instead of
                    # feeding the app empty "more_body" chunks forever.
                    more_body = bool(data) and read_count < content_length
                    self.receive_event.set(
                        {"type": "http.request", "body": data, "more_body": more_body}
                    )
                    receive_eof = not more_body
            else:
                raise RuntimeError(f"Unknown message type: {message_type}")

            if self.wsgi_should_stop:
                self.receive_event.set({"type": "http.disconnect"})
                break

            if self.asgi_done.is_set():
                break

        # HTTP response ends, wait for run_asgi's background tasks
        self.asgi_done.wait(self.wait_time)
        self.loop.call_soon_threadsafe(asgi_task.cancel)
        yield b""
|