22
22
from .h3 import H3Protocol
23
23
from ..config import Config
24
24
from ..events import Closed , Event , RawData
25
- from ..typing import AppWrapper , TaskGroup , WorkerContext
25
+ from ..typing import AppWrapper , TaskGroup , WorkerContext , Timer
26
26
27
27
28
28
class QuicProtocol :
@@ -40,6 +40,7 @@ def __init__(
40
40
self .context = context
41
41
self .connections : Dict [bytes , QuicConnection ] = {}
42
42
self .http_connections : Dict [QuicConnection , H3Protocol ] = {}
43
+ self .timers : Dict [QuicConnection , Timer ] = {}
43
44
self .send = send
44
45
self .server = server
45
46
self .task_group = task_group
@@ -82,10 +83,12 @@ async def handle(self, event: Event) -> None:
82
83
)
83
84
self .connections [header .destination_cid ] = connection
84
85
self .connections [connection .host_cid ] = connection
86
+ # This partial() needs python >= 3.8
87
+ self .timers [connection ] = self .task_group .create_timer (partial (self ._timeout , connection ))
85
88
86
89
if connection is not None :
87
90
connection .receive_datagram (event .data , event .address , now = self .context .time ())
88
- await self ._handle_events (connection , event . address )
91
+ await self ._wake_up_timer (connection )
89
92
elif isinstance (event , Closed ):
90
93
pass
91
94
@@ -99,7 +102,16 @@ async def _handle_events(
99
102
event = connection .next_event ()
100
103
while event is not None :
101
104
if isinstance (event , ConnectionTerminated ):
102
- pass
105
+ await self .timers [connection ].stop ()
106
+ del self .timers [connection ]
107
+ # XXXRTH This is not the speediest! Better would be tracking
108
+ # assigned ids in a set.
109
+ prune = []
110
+ for tcid , tconn in self .connections .items ():
111
+ if tconn == connection :
112
+ prune .append (tcid )
113
+ for tcid in prune :
114
+ del self .connections [tcid ]
103
115
elif isinstance (event , ProtocolNegotiated ):
104
116
self .http_connections [connection ] = H3Protocol (
105
117
self .app ,
@@ -109,7 +121,7 @@ async def _handle_events(
109
121
client ,
110
122
self .server ,
111
123
connection ,
112
- partial (self .send_all , connection ),
124
+ partial (self ._wake_up_timer , connection ),
113
125
)
114
126
elif isinstance (event , ConnectionIdIssued ):
115
127
self .connections [event .connection_id ] = connection
@@ -121,15 +133,20 @@ async def _handle_events(
121
133
122
134
event = connection .next_event ()
123
135
136
+ async def _wake_up_timer (self , connection : QuicConnection ):
137
+ # When new output is send, or new input is received, we
138
+ # fire the timer right away so we update our state.
139
+ timer = self .timers .get (connection )
140
+ if timer is not None :
141
+ await timer .schedule (0.0 )
142
+
143
+ async def _timeout (self , connection : QuicConnection ):
144
+ now = self .context .time ()
145
+ when = connection .get_timer ()
146
+ if when is not None and now > when :
147
+ connection .handle_timer (now )
148
+ await self ._handle_events (connection , None )
124
149
await self .send_all (connection )
125
-
126
- timer = connection .get_timer ()
150
+ timer = self .timers .get (connection )
127
151
if timer is not None :
128
- self .task_group .spawn (self ._handle_timer , timer , connection )
129
-
130
- async def _handle_timer (self , timer : float , connection : QuicConnection ) -> None :
131
- wait = max (0 , timer - self .context .time ())
132
- await self .context .sleep (wait )
133
- if connection ._close_at is not None :
134
- connection .handle_timer (now = self .context .time ())
135
- await self ._handle_events (connection , None )
152
+ await timer .schedule (connection .get_timer ())
0 commit comments