Skip to content

Commit

Permalink
feat: support Django asgi middleware (#625)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche authored Oct 6, 2022
1 parent 81fb6c2 commit f52b3aa
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 24 deletions.
19 changes: 8 additions & 11 deletions google/cloud/logging_v2/handlers/middleware/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,20 @@ def _get_django_request():
return getattr(_thread_locals, "request", None)


try:
from django.utils.deprecation import MiddlewareMixin
except ImportError: # pragma: NO COVER
MiddlewareMixin = object


class RequestMiddleware(MiddlewareMixin):
def RequestMiddleware(get_response):
"""Saves the request in thread local"""

def __init__(self, get_response):
self.get_response = get_response

def process_request(self, request):
def middleware(request):
"""Called on each request, before Django decides which view to execute.
Args:
request(django.http.request.HttpRequest):
Django http request.
"""
_thread_locals.request = request
if get_response:
return get_response(request)
else:
return None

return middleware
2 changes: 1 addition & 1 deletion tests/environment
17 changes: 10 additions & 7 deletions tests/unit/handlers/middleware/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,29 @@ def _make_one(self, *args, **kw):

return self._get_target_class()(*args, **kw)

def _mock_get_response(self, req):
return req

def test_process_request(self):
from django.test import RequestFactory
from google.cloud.logging_v2.handlers.middleware import request

middleware = self._make_one()
mock_request = RequestFactory().get("/")
middleware.process_request(mock_request)
middleware(mock_request)

django_request = request._get_django_request()
self.assertEqual(django_request, mock_request)

def test_can_instantiate_middleware_without_kwargs(self):
handler = mock.Mock()
middleware = self._make_one(handler)
self.assertEqual(middleware.get_response, handler)
middleware = self._make_one(self._mock_get_response)
mock_request = "test_req"
self.assertEqual(middleware(mock_request), mock_request)

def test_can_instantiate_middleware_with_kwargs(self):
handler = mock.Mock()
middleware = self._make_one(get_response=handler)
self.assertEqual(middleware.get_response, handler)
middleware = self._make_one(get_response=self._mock_get_response)
mock_request = "test_req"
self.assertEqual(middleware(mock_request), mock_request)


class Test__get_django_request(DjangoBase):
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/handlers/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_no_context_header(self):
django_request = RequestFactory().get("/")

middleware = request.RequestMiddleware(None)
middleware.process_request(django_request)
middleware(django_request)
http_request, trace_id, span_id, sampled = self._call_fut()

self.assertEqual(http_request["requestMethod"], "GET")
Expand All @@ -175,7 +175,7 @@ def test_xcloud_header(self):
)

middleware = request.RequestMiddleware(None)
middleware.process_request(django_request)
middleware(django_request)
http_request, trace_id, span_id, sampled = self._call_fut()

self.assertEqual(trace_id, expected_trace_id)
Expand All @@ -195,7 +195,7 @@ def test_traceparent_header(self):
django_request = RequestFactory().get("/", **{django_trace_header: header})

middleware = request.RequestMiddleware(None)
middleware.process_request(django_request)
middleware(django_request)
http_request, trace_id, span_id, sampled = self._call_fut()

self.assertEqual(trace_id, expected_trace_id)
Expand All @@ -222,7 +222,7 @@ def test_http_request_populated(self):
django_request.read()

middleware = request.RequestMiddleware(None)
middleware.process_request(django_request)
middleware(django_request)
http_request, *_ = self._call_fut()
self.assertEqual(http_request["requestMethod"], "PUT")
self.assertEqual(http_request["requestUrl"], expected_path)
Expand All @@ -236,7 +236,7 @@ def test_http_request_sparse(self):
expected_path = "http://testserver/123"
django_request = RequestFactory().put(expected_path)
middleware = request.RequestMiddleware(None)
middleware.process_request(django_request)
middleware(django_request)
http_request, *_ = self._call_fut()
self.assertEqual(http_request["requestMethod"], "PUT")
self.assertEqual(http_request["requestUrl"], expected_path)
Expand Down

0 comments on commit f52b3aa

Please sign in to comment.