Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class StarryPyServer:
"""
def __init__(self, reader, writer, config, factory):
logger.debug("Initializing connection.")
self._reader = reader # read packets from client
self._writer = writer # writes packets to client
self._reader = ZstdFrameReader(reader, Direction.TO_SERVER) # read packets from client
self._writer = ZstdFrameWriter(writer) # writes packets to client
self._client_reader = None # read packets from server (acting as client)
self._client_writer = None # write packets to server
self.factory = factory
Expand All @@ -48,17 +48,13 @@ def __init__(self, reader, writer, config, factory):
self._client_read_future = None
self._server_write_future = None
self._client_write_future = None
self._expect_server_loop_death = False
logger.info("Received connection from {}".format(self.client_ip))

def start_zstd(self):
self._reader = ZstdFrameReader(self._reader, Direction.TO_SERVER)
self._client_reader= ZstdFrameReader(self._client_reader, Direction.TO_CLIENT)
self._writer = ZstdFrameWriter(self._writer, skip_packets=1)
self._client_writer = ZstdFrameWriter(self._client_writer)
self._expect_server_loop_death = True
self._server_loop_future.cancel()
self._server_loop_future = asyncio.create_task(self.server_loop())
self._reader.enable_zstd()
self._client_reader.enable_zstd()
self._writer.enable_zstd(skip_packets=1) # skip this packet
self._client_writer.enable_zstd()
logger.info("Switched to zstd")


Expand Down Expand Up @@ -95,12 +91,8 @@ async def server_loop(self):
"{}: {}".format(err.__class__.__name__, err))
logger.error("Error details and traceback: {}".format(traceback.format_exc()))
finally:
if not self._expect_server_loop_death:
logger.info("Server loop ended.")
self.die()
else:
logger.info("Restarting server loop for switch to zstd.")
self._expect_server_loop_death = False
logger.info("Server loop ended.")
self.die()

async def client_loop(self):
"""
Expand All @@ -109,9 +101,11 @@ async def client_loop(self):

:return:
"""
(self._client_reader, self._client_writer) = \
await asyncio.open_connection(self.config['upstream_host'],
(reader, writer) = await asyncio.open_connection(self.config['upstream_host'],
self.config['upstream_port'])

self._client_reader = ZstdFrameReader(reader, Direction.TO_CLIENT)
self._client_writer = ZstdFrameWriter(writer)

try:
while True:
Expand Down
17 changes: 12 additions & 5 deletions zstd_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ def __init__(self, reader: asyncio.StreamReader, direction: Direction):
self.decompressor = zstd.ZstdDecompressor().stream_writer(self.outputbuffer)
self.raw_reader = reader
self.direction = direction
self.zstd_enabled = False

def enable_zstd(self):
self.zstd_enabled = True

async def readexactly(self, count):
# print(f"Reading exactly {count} bytes")
Expand All @@ -31,11 +35,14 @@ async def read_from_network(self, target_count):
# print(f"Read {len(chunk)} bytes from network")
if not chunk:
raise asyncio.CancelledError("Connection closed")
try:
self.decompressor.write(chunk)
except zstd.ZstdError:
print("Zstd error, dropping connection")
raise asyncio.CancelledError("Error in compressed data stream!")
if not self.zstd_enabled:
self.outputbuffer.write(chunk)
else:
try:
self.decompressor.write(chunk)
except zstd.ZstdError:
print("Zstd error, dropping connection")
raise asyncio.CancelledError("Error in compressed data stream!")

class NonSeekableMemoryStream(io.RawIOBase):
def __init__(self):
Expand Down
11 changes: 10 additions & 1 deletion zstd_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
import zstandard as zstd

class ZstdFrameWriter:
def __init__(self, raw_writer: asyncio.StreamWriter, skip_packets=0):
def __init__(self, raw_writer: asyncio.StreamWriter):
self.compressor = zstd.ZstdCompressor()
self.raw_writer = raw_writer
self.skip_packets = 0
self.zstd_enabled = False

def enable_zstd(self, skip_packets=0):
self.zstd_enabled = True
self.skip_packets = skip_packets

async def drain(self):
Expand All @@ -16,6 +21,10 @@ def close(self):
self.compressor = None

def write(self, data):

if not self.zstd_enabled:
self.raw_writer.write(data)
return

if self.skip_packets > 0:
self.skip_packets -= 1
Expand Down