diff --git a/quic/safe_stream.go b/quic/safe_stream.go index db347aac..a1575b46 100644 --- a/quic/safe_stream.go +++ b/quic/safe_stream.go @@ -4,6 +4,7 @@ import ( "errors" "net" "sync" + "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -19,6 +20,7 @@ type SafeStreamCloser struct { stream quic.Stream writeTimeout time.Duration log *zerolog.Logger + closing atomic.Bool } func NewSafeStreamCloser(stream quic.Stream, writeTimeout time.Duration, log *zerolog.Logger) *SafeStreamCloser { @@ -44,27 +46,35 @@ func (s *SafeStreamCloser) Write(p []byte) (n int, err error) { } nBytes, err := s.stream.Write(p) if err != nil { - s.handleTimeout(err) + s.handleWriteError(err) } return nBytes, err } // Handles the timeout error in case it happened, by canceling the stream write. -func (s *SafeStreamCloser) handleTimeout(err error) { +func (s *SafeStreamCloser) handleWriteError(err error) { + // If we are closing the stream we just ignore any write error. + if s.closing.Load() { + return + } var netErr net.Error if errors.As(err, &netErr) { if netErr.Timeout() { - // We don't need to log if what cause the timeout was `no network activity`. + // We don't need to log if what cause the timeout was no network activity. if !errors.Is(netErr, &idleTimeoutError) { s.log.Error().Err(netErr).Msg("Closing quic stream due to timeout while writing") } + // We need to explicitly cancel the write so that it frees all buffers. s.stream.CancelWrite(0) } } } func (s *SafeStreamCloser) Close() error { + // Set this stream to a closing state. + s.closing.Store(true) + // Make sure a possible writer does not block the lock forever. We need it, so we can close the writer // side of the stream safely. _ = s.stream.SetWriteDeadline(time.Now())