diff --git a/server.py b/server.py index 8478ef8..a3fe3d0 100644 --- a/server.py +++ b/server.py @@ -33,8 +33,8 @@ class StarryPyServer: """ def __init__(self, reader, writer, config, factory): logger.debug("Initializing connection.") - self._reader = reader # read packets from client - self._writer = writer # writes packets to client + self._reader = ZstdFrameReader(reader, Direction.TO_SERVER) # read packets from client + self._writer = ZstdFrameWriter(writer) # writes packets to client self._client_reader = None # read packets from server (acting as client) self._client_writer = None # write packets to server self.factory = factory @@ -48,17 +48,13 @@ def __init__(self, reader, writer, config, factory): self._client_read_future = None self._server_write_future = None self._client_write_future = None - self._expect_server_loop_death = False logger.info("Received connection from {}".format(self.client_ip)) def start_zstd(self): - self._reader = ZstdFrameReader(self._reader, Direction.TO_SERVER) - self._client_reader= ZstdFrameReader(self._client_reader, Direction.TO_CLIENT) - self._writer = ZstdFrameWriter(self._writer, skip_packets=1) - self._client_writer = ZstdFrameWriter(self._client_writer) - self._expect_server_loop_death = True - self._server_loop_future.cancel() - self._server_loop_future = asyncio.create_task(self.server_loop()) + self._reader.enable_zstd() + self._client_reader.enable_zstd() + self._writer.enable_zstd(skip_packets=1) # skip this packet + self._client_writer.enable_zstd() logger.info("Switched to zstd") @@ -95,12 +91,8 @@ async def server_loop(self): "{}: {}".format(err.__class__.__name__, err)) logger.error("Error details and traceback: {}".format(traceback.format_exc())) finally: - if not self._expect_server_loop_death: - logger.info("Server loop ended.") - self.die() - else: - logger.info("Restarting server loop for switch to zstd.") - self._expect_server_loop_death = False + logger.info("Server loop ended.") + self.die() async def client_loop(self): """ @@ -109,9 +101,11 @@ async def client_loop(self): :return: """ - (self._client_reader, self._client_writer) = \ - await asyncio.open_connection(self.config['upstream_host'], + (reader, writer) = await asyncio.open_connection(self.config['upstream_host'], self.config['upstream_port']) + + self._client_reader = ZstdFrameReader(reader, Direction.TO_CLIENT) + self._client_writer = ZstdFrameWriter(writer) try: while True: diff --git a/zstd_reader.py b/zstd_reader.py index eb2fa71..070c67b 100644 --- a/zstd_reader.py +++ b/zstd_reader.py @@ -11,6 +11,10 @@ def __init__(self, reader: asyncio.StreamReader, direction: Direction): self.decompressor = zstd.ZstdDecompressor().stream_writer(self.outputbuffer) self.raw_reader = reader self.direction = direction + self.zstd_enabled = False + + def enable_zstd(self): + self.zstd_enabled = True async def readexactly(self, count): # print(f"Reading exactly {count} bytes") @@ -31,11 +35,14 @@ async def read_from_network(self, target_count): # print(f"Read {len(chunk)} bytes from network") if not chunk: raise asyncio.CancelledError("Connection closed") - try: - self.decompressor.write(chunk) - except zstd.ZstdError: - print("Zstd error, dropping connection") - raise asyncio.CancelledError("Error in compressed data stream!") + if not self.zstd_enabled: + self.outputbuffer.write(chunk) + else: + try: + self.decompressor.write(chunk) + except zstd.ZstdError: + print("Zstd error, dropping connection") + raise asyncio.CancelledError("Error in compressed data stream!") class NonSeekableMemoryStream(io.RawIOBase): def __init__(self): diff --git a/zstd_writer.py b/zstd_writer.py index e6c34b5..d2416eb 100644 --- a/zstd_writer.py +++ b/zstd_writer.py @@ -3,9 +3,14 @@ import zstandard as zstd class ZstdFrameWriter: - def __init__(self, raw_writer: asyncio.StreamWriter, skip_packets=0): + def __init__(self, raw_writer: asyncio.StreamWriter): self.compressor = zstd.ZstdCompressor() self.raw_writer = raw_writer + self.skip_packets = 0 + self.zstd_enabled = False + + def enable_zstd(self, skip_packets=0): + self.zstd_enabled = True self.skip_packets = skip_packets async def drain(self): @@ -16,6 +21,10 @@ def close(self): self.compressor = None def write(self, data): + + if not self.zstd_enabled: + self.raw_writer.write(data) + return if self.skip_packets > 0: self.skip_packets -= 1