Skip to content

Commit 8f2f048

Browse files
Ensure writer is always reset on completion (#7815)
1 parent 366ba40 commit 8f2f048

File tree

4 files changed

+62
-34
lines changed

4 files changed

+62
-34
lines changed

CHANGES/7815.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed an issue where the client could go into an infinite loop. -- by :user:`Dreamsorcerer`

aiohttp/client_reqrep.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,13 @@
5656
reify,
5757
set_result,
5858
)
59-
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
59+
from .http import (
60+
SERVER_SOFTWARE,
61+
HttpVersion,
62+
HttpVersion10,
63+
HttpVersion11,
64+
StreamWriter,
65+
)
6066
from .log import client_logger
6167
from .streams import StreamReader
6268
from .typedefs import (
@@ -178,7 +184,7 @@ class ClientRequest:
178184
auth = None
179185
response = None
180186

181-
_writer = None # async task for streaming data
187+
__writer = None # async task for streaming data
182188
_continue = None # waiter future for '100 Continue' response
183189

184190
# N.B.
@@ -265,6 +271,21 @@ def __init__(
265271
traces = []
266272
self._traces = traces
267273

274+
def __reset_writer(self, _: object = None) -> None:
275+
self.__writer = None
276+
277+
@property
278+
def _writer(self) -> Optional["asyncio.Task[None]"]:
279+
return self.__writer
280+
281+
@_writer.setter
282+
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
283+
if self.__writer is not None:
284+
self.__writer.remove_done_callback(self.__reset_writer)
285+
self.__writer = writer
286+
if writer is not None:
287+
writer.add_done_callback(self.__reset_writer)
288+
268289
def is_ssl(self) -> bool:
269290
return self.url.scheme in ("https", "wss")
270291

@@ -563,8 +584,6 @@ async def write_bytes(
563584
else:
564585
await writer.write_eof()
565586
protocol.start_timeout()
566-
finally:
567-
self._writer = None
568587

569588
async def send(self, conn: "Connection") -> "ClientResponse":
570589
# Specify request target:
@@ -649,16 +668,14 @@ async def send(self, conn: "Connection") -> "ClientResponse":
649668

650669
async def close(self) -> None:
651670
if self._writer is not None:
652-
try:
653-
with contextlib.suppress(asyncio.CancelledError):
654-
await self._writer
655-
finally:
656-
self._writer = None
671+
with contextlib.suppress(asyncio.CancelledError):
672+
await self._writer
657673

658674
def terminate(self) -> None:
659675
if self._writer is not None:
660676
if not self.loop.is_closed():
661677
self._writer.cancel()
678+
self._writer.remove_done_callback(self.__reset_writer)
662679
self._writer = None
663680

664681
async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
@@ -677,9 +694,9 @@ class ClientResponse(HeadersMixin):
677694
# but will be set by the start() method.
678695
# As the end user will likely never see the None values, we cheat the types below.
679696
# from the Status-Line of the response
680-
version = None # HTTP-Version
697+
version: Optional[HttpVersion] = None # HTTP-Version
681698
status: int = None # type: ignore[assignment] # Status-Code
682-
reason = None # Reason-Phrase
699+
reason: Optional[str] = None # Reason-Phrase
683700

684701
content: StreamReader = None # type: ignore[assignment] # Payload stream
685702
_headers: CIMultiDictProxy[str] = None # type: ignore[assignment]
@@ -691,6 +708,7 @@ class ClientResponse(HeadersMixin):
691708
# post-init stage allows to not change ctor signature
692709
_closed = True # to allow __del__ for non-initialized properly response
693710
_released = False
711+
__writer = None
694712

695713
def __init__(
696714
self,
@@ -737,6 +755,21 @@ def __init__(
737755
if loop.get_debug():
738756
self._source_traceback = traceback.extract_stack(sys._getframe(1))
739757

758+
def __reset_writer(self, _: object = None) -> None:
759+
self.__writer = None
760+
761+
@property
762+
def _writer(self) -> Optional["asyncio.Task[None]"]:
763+
return self.__writer
764+
765+
@_writer.setter
766+
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
767+
if self.__writer is not None:
768+
self.__writer.remove_done_callback(self.__reset_writer)
769+
self.__writer = writer
770+
if writer is not None:
771+
writer.add_done_callback(self.__reset_writer)
772+
740773
@reify
741774
def url(self) -> URL:
742775
return self._url
@@ -797,7 +830,7 @@ def __repr__(self) -> str:
797830
"ascii", "backslashreplace"
798831
).decode("ascii")
799832
else:
800-
ascii_encodable_reason = self.reason
833+
ascii_encodable_reason = "None"
801834
print(
802835
"<ClientResponse({}) [{} {}]>".format(
803836
ascii_encodable_url, self.status, ascii_encodable_reason
@@ -978,18 +1011,12 @@ def _release_connection(self) -> None:
9781011

9791012
async def _wait_released(self) -> None:
9801013
if self._writer is not None:
981-
try:
982-
await self._writer
983-
finally:
984-
self._writer = None
1014+
await self._writer
9851015
self._release_connection()
9861016

9871017
def _cleanup_writer(self) -> None:
9881018
if self._writer is not None:
989-
if self._writer.done():
990-
self._writer = None
991-
else:
992-
self._writer.cancel()
1019+
self._writer.cancel()
9931020
self._session = None
9941021

9951022
def _notify_content(self) -> None:
@@ -1001,10 +1028,7 @@ def _notify_content(self) -> None:
10011028

10021029
async def wait_for_close(self) -> None:
10031030
if self._writer is not None:
1004-
try:
1005-
await self._writer
1006-
finally:
1007-
self._writer = None
1031+
await self._writer
10081032
self.release()
10091033

10101034
async def read(self) -> bytes:

tests/test_client_response.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import gc
55
import sys
66
from json import JSONDecodeError
7-
from typing import Any
7+
from typing import Any, Callable
88
from unittest import mock
99

1010
import pytest
@@ -22,6 +22,9 @@ class WriterMock(mock.AsyncMock):
2222
def __await__(self) -> None:
2323
return self().__await__()
2424

25+
def add_done_callback(self, cb: Callable[[], None]) -> None:
26+
cb()
27+
2528
def done(self) -> bool:
2629
return True
2730

tests/test_proxy.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def test_proxy_server_hostname_default(self, ClientRequestMock) -> None:
199199
"get",
200200
URL("http://proxy.example.com"),
201201
request_info=mock.Mock(),
202-
writer=mock.Mock(),
202+
writer=None,
203203
continue100=None,
204204
timer=TimerNoop(),
205205
traces=[],
@@ -261,7 +261,7 @@ def test_proxy_server_hostname_override(self, ClientRequestMock) -> None:
261261
"get",
262262
URL("http://proxy.example.com"),
263263
request_info=mock.Mock(),
264-
writer=mock.Mock(),
264+
writer=None,
265265
continue100=None,
266266
timer=TimerNoop(),
267267
traces=[],
@@ -323,7 +323,7 @@ def test_https_connect(self, ClientRequestMock: Any) -> None:
323323
"get",
324324
URL("http://proxy.example.com"),
325325
request_info=mock.Mock(),
326-
writer=mock.Mock(),
326+
writer=None,
327327
continue100=None,
328328
timer=TimerNoop(),
329329
traces=[],
@@ -383,7 +383,7 @@ def test_https_connect_certificate_error(self, ClientRequestMock: Any) -> None:
383383
"get",
384384
URL("http://proxy.example.com"),
385385
request_info=mock.Mock(),
386-
writer=mock.Mock(),
386+
writer=None,
387387
continue100=None,
388388
timer=TimerNoop(),
389389
traces=[],
@@ -437,7 +437,7 @@ def test_https_connect_ssl_error(self, ClientRequestMock: Any) -> None:
437437
"get",
438438
URL("http://proxy.example.com"),
439439
request_info=mock.Mock(),
440-
writer=mock.Mock(),
440+
writer=None,
441441
continue100=None,
442442
timer=TimerNoop(),
443443
traces=[],
@@ -493,7 +493,7 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock: Any) -> None:
493493
"get",
494494
URL("http://proxy.example.com"),
495495
request_info=mock.Mock(),
496-
writer=mock.Mock(),
496+
writer=None,
497497
continue100=None,
498498
timer=TimerNoop(),
499499
traces=[],
@@ -552,7 +552,7 @@ def test_https_connect_resp_start_error(self, ClientRequestMock: Any) -> None:
552552
"get",
553553
URL("http://proxy.example.com"),
554554
request_info=mock.Mock(),
555-
writer=mock.Mock(),
555+
writer=None,
556556
continue100=None,
557557
timer=TimerNoop(),
558558
traces=[],
@@ -663,7 +663,7 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock: Any) -> None:
663663
"get",
664664
URL("http://proxy.example.com"),
665665
request_info=mock.Mock(),
666-
writer=mock.Mock(),
666+
writer=None,
667667
continue100=None,
668668
timer=TimerNoop(),
669669
traces=[],
@@ -734,7 +734,7 @@ def test_https_auth(self, ClientRequestMock: Any) -> None:
734734
"get",
735735
URL("http://proxy.example.com"),
736736
request_info=mock.Mock(),
737-
writer=mock.Mock(),
737+
writer=None,
738738
continue100=None,
739739
timer=TimerNoop(),
740740
traces=[],

0 commit comments

Comments
 (0)