fix(protocols): guard against zero/malformed length fields in binary protocol parsers

MongoDB had the same infinite-loop bug as MSSQL (msg_len=0 → buffer never
shrinks in while loop). Postgres, MySQL, and MQTT had related length-field
issues (stuck state, resource exhaustion, overlong remaining-length).

Also fixes an existing MongoDB _op_reply struct.pack format bug (extra 'q'
specifier caused struct.error on any OP_QUERY response).

Adds 53 regression + protocol boundary tests across MSSQL, MongoDB,
Postgres, MySQL, and MQTT, including a _run_with_timeout threading harness
to catch infinite loops and @pytest.mark.fuzz hypothesis tests for each.
This commit is contained in:
2026-04-12 01:01:13 -04:00
parent 65d585569b
commit d63e396410
10 changed files with 894 additions and 2 deletions

View File

@@ -35,13 +35,13 @@ def _op_reply(request_id: int, doc: bytes) -> bytes:
# OP_REPLY header: total_len(4), req_id(4), response_to(4), opcode(4)=1,
# flags(4), cursor_id(8), starting_from(4), number_returned(4), docs
header = struct.pack(
"<iiiiiqqii",
"<iiiiiqii",
16 + 20 + len(doc), # total length
0, # request id
request_id, # response to
1, # OP_REPLY
0, # flags
0, # cursor id
0, # cursor id (int64)
0, # starting from
1, # number returned
)
@@ -81,6 +81,10 @@ class MongoDBProtocol(asyncio.Protocol):
self._buf += data
while len(self._buf) >= 16:
msg_len = struct.unpack("<I", self._buf[:4])[0]
if msg_len < 16 or msg_len > 48 * 1024 * 1024:
self._transport.close()
self._buf = b""
return
if len(self._buf) < msg_len:
break
msg = self._buf[:msg_len]

View File

@@ -191,6 +191,10 @@ class MQTTProtocol(asyncio.Protocol):
remaining = 0
multiplier = 1
while pos < len(self._buf):
if pos > 4: # MQTT spec: max 4 bytes for remaining length
self._transport.close()
self._buf = b""
return
byte = self._buf[pos]
remaining += (byte & 0x7f) * multiplier
multiplier *= 128

View File

@@ -67,6 +67,10 @@ class MySQLProtocol(asyncio.Protocol):
# MySQL packets: 3-byte length + 1-byte seq + payload
while len(self._buf) >= 4:
length = struct.unpack("<I", self._buf[:3] + b"\x00")[0]
if length > 1024 * 1024:
self._transport.close()
self._buf = b""
return
if len(self._buf) < 4 + length:
break
payload = self._buf[4:4 + length]

View File

@@ -49,6 +49,10 @@ class PostgresProtocol(asyncio.Protocol):
if len(self._buf) < 4:
return
msg_len = struct.unpack(">I", self._buf[:4])[0]
if msg_len < 8 or msg_len > 10_000:
self._transport.close()
self._buf = b""
return
if len(self._buf) < msg_len:
return
msg = self._buf[:msg_len]
@@ -59,6 +63,10 @@ class PostgresProtocol(asyncio.Protocol):
return
msg_type = chr(self._buf[0])
msg_len = struct.unpack(">I", self._buf[1:5])[0]
if msg_len < 4 or msg_len > 10_000:
self._transport.close()
self._buf = b""
return
if len(self._buf) < msg_len + 1:
return
payload = self._buf[5:msg_len + 1]

View File

@@ -0,0 +1,47 @@
"""
Shared helpers for binary-protocol service tests.
"""
import os
import threading
from types import ModuleType
from unittest.mock import MagicMock
import pytest
from hypothesis import HealthCheck
_FUZZ_SETTINGS = dict(
max_examples=int(os.environ.get("HYPOTHESIS_MAX_EXAMPLES", "200")),
deadline=2000,
suppress_health_check=[HealthCheck.function_scoped_fixture],
)
def make_fake_decnet_logging() -> ModuleType:
mod = ModuleType("decnet_logging")
mod.syslog_line = MagicMock(return_value="")
mod.write_syslog_file = MagicMock()
mod.forward_syslog = MagicMock()
mod.SEVERITY_WARNING = 4
mod.SEVERITY_INFO = 6
return mod
def run_with_timeout(fn, *args, timeout: float = 2.0) -> None:
"""Run fn(*args) in a daemon thread. pytest.fail if it doesn't return in time."""
exc_box: list[BaseException] = []
def _target():
try:
fn(*args)
except Exception as e:
exc_box.append(e)
t = threading.Thread(target=_target, daemon=True)
t.start()
t.join(timeout)
if t.is_alive():
pytest.fail(f"data_received hung for >{timeout}s — likely infinite loop")
if exc_box:
raise exc_box[0]

View File

@@ -0,0 +1,161 @@
"""
Tests for templates/mongodb/server.py
Covers the MongoDB wire-protocol (OP_MSG / OP_QUERY) happy path and regression
tests for the zero-length msg_len infinite-loop bug and oversized msg_len.
"""
import importlib.util
import struct
import sys
from unittest.mock import MagicMock
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
# ── Helpers ───────────────────────────────────────────────────────────────────
def _load_mongodb():
for key in list(sys.modules):
if key in ("mongodb_server", "decnet_logging"):
del sys.modules[key]
sys.modules["decnet_logging"] = make_fake_decnet_logging()
spec = importlib.util.spec_from_file_location("mongodb_server", "templates/mongodb/server.py")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def _make_protocol(mod):
proto = mod.MongoDBProtocol()
transport = MagicMock()
written: list[bytes] = []
transport.write.side_effect = written.append
proto.connection_made(transport)
return proto, transport, written
def _minimal_bson() -> bytes:
return b"\x05\x00\x00\x00\x00" # empty document
def _op_msg_packet(request_id: int = 1) -> bytes:
"""Build a valid OP_MSG with an empty BSON body."""
flag_bits = struct.pack("<I", 0)
section = b"\x00" + _minimal_bson()
body = flag_bits + section
total = 16 + len(body)
header = struct.pack("<iiii", total, request_id, 0, 2013)
return header + body
def _op_query_packet(request_id: int = 2) -> bytes:
"""Build a minimal OP_QUERY."""
flags = struct.pack("<I", 0)
coll = b"admin.$cmd\x00"
skip = struct.pack("<I", 0)
ret = struct.pack("<I", 1)
query = _minimal_bson()
body = flags + coll + skip + ret + query
total = 16 + len(body)
header = struct.pack("<iiii", total, request_id, 0, 2004)
return header + body
@pytest.fixture
def mongodb_mod():
return _load_mongodb()
# ── Happy path ────────────────────────────────────────────────────────────────
def test_op_msg_returns_response(mongodb_mod):
proto, _, written = _make_protocol(mongodb_mod)
proto.data_received(_op_msg_packet())
assert written, "expected a response to OP_MSG"
def test_op_msg_response_opcode_is_2013(mongodb_mod):
proto, _, written = _make_protocol(mongodb_mod)
proto.data_received(_op_msg_packet())
resp = b"".join(written)
assert len(resp) >= 16
opcode = struct.unpack("<i", resp[12:16])[0]
assert opcode == 2013
def test_op_query_returns_op_reply(mongodb_mod):
proto, _, written = _make_protocol(mongodb_mod)
proto.data_received(_op_query_packet())
resp = b"".join(written)
assert len(resp) >= 16
opcode = struct.unpack("<i", resp[12:16])[0]
assert opcode == 1
def test_partial_header_waits_for_more_data(mongodb_mod):
proto, transport, _ = _make_protocol(mongodb_mod)
proto.data_received(b"\x1a\x00\x00\x00") # only 4 bytes (< 16)
transport.close.assert_not_called()
def test_two_consecutive_messages(mongodb_mod):
proto, _, written = _make_protocol(mongodb_mod)
two = _op_msg_packet(1) + _op_msg_packet(2)
proto.data_received(two)
assert len(written) >= 2
def test_connection_lost_does_not_raise(mongodb_mod):
proto, _, _ = _make_protocol(mongodb_mod)
proto.connection_lost(None)
# ── Regression: malformed msg_len ────────────────────────────────────────────
def test_zero_msg_len_closes(mongodb_mod):
proto, transport, _ = _make_protocol(mongodb_mod)
# msg_len = 0 at bytes [0:4] LE — buffer has 16 bytes so outer while triggers
data = b"\x00\x00\x00\x00" + b"\x00" * 12
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
def test_msg_len_15_closes(mongodb_mod):
proto, transport, _ = _make_protocol(mongodb_mod)
# msg_len = 15 (below 16-byte wire-protocol minimum)
data = struct.pack("<I", 15) + b"\x00" * 12
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
def test_msg_len_over_48mb_closes(mongodb_mod):
proto, transport, _ = _make_protocol(mongodb_mod)
# msg_len = 48MB + 1
big = 48 * 1024 * 1024 + 1
data = struct.pack("<I", big) + b"\x00" * 12
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
def test_msg_len_exactly_48mb_plus1_closes(mongodb_mod):
proto, transport, _ = _make_protocol(mongodb_mod)
# cap is strictly > 48MB, so 48MB+1 must close
data = struct.pack("<I", 48 * 1024 * 1024 + 1) + b"\x00" * 12
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
# ── Fuzz ──────────────────────────────────────────────────────────────────────
@pytest.mark.fuzz
@given(data=st.binary(min_size=0, max_size=512))
@settings(**_FUZZ_SETTINGS)
def test_fuzz_arbitrary_bytes(data):
mod = _load_mongodb()
proto, _, _ = _make_protocol(mod)
run_with_timeout(proto.data_received, data)

View File

@@ -0,0 +1,185 @@
"""
Tests for templates/mqtt/server.py — protocol boundary and fuzz cases.
Focuses on the variable-length remaining-length field (MQTT spec: max 4 bytes).
A 5th continuation byte used to cause the server to get stuck waiting for a
payload it could never receive (remaining = hundreds of MB).
"""
import importlib.util
import struct
import sys
from unittest.mock import MagicMock, patch
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
# ── Helpers ───────────────────────────────────────────────────────────────────
def _load_mqtt():
for key in list(sys.modules):
if key in ("mqtt_server", "decnet_logging"):
del sys.modules[key]
sys.modules["decnet_logging"] = make_fake_decnet_logging()
spec = importlib.util.spec_from_file_location("mqtt_server", "templates/mqtt/server.py")
mod = importlib.util.module_from_spec(spec)
with patch.dict("os.environ", {"MQTT_ACCEPT_ALL": "1", "MQTT_PERSONA": "water_plant"}, clear=False):
spec.loader.exec_module(mod)
return mod
def _make_protocol(mod):
proto = mod.MQTTProtocol()
transport = MagicMock()
written: list[bytes] = []
transport.write.side_effect = written.append
proto.connection_made(transport)
return proto, transport, written
def _connect_packet(client_id: str = "test-client") -> bytes:
"""Build a minimal MQTT CONNECT packet."""
proto_name = b"\x00\x04MQTT"
proto_level = b"\x04" # 3.1.1
flags = b"\x02" # clean session
keepalive = b"\x00\x3c"
cid = client_id.encode()
cid_field = struct.pack(">H", len(cid)) + cid
payload = proto_name + proto_level + flags + keepalive + cid_field
remaining = len(payload)
# single-byte remaining length (works for short payloads)
return bytes([0x10, remaining]) + payload
def _encode_remaining(value: int) -> bytes:
"""Encode a value using MQTT variable-length encoding."""
result = []
while True:
encoded = value % 128
value //= 128
if value > 0:
encoded |= 128
result.append(encoded)
if value == 0:
break
return bytes(result)
@pytest.fixture
def mqtt_mod():
return _load_mqtt()
# ── Happy path ────────────────────────────────────────────────────────────────
def test_connect_returns_connack_accepted(mqtt_mod):
proto, _, written = _make_protocol(mqtt_mod)
proto.data_received(_connect_packet())
resp = b"".join(written)
assert resp[:2] == b"\x20\x02" # CONNACK
assert resp[3:4] == b"\x00" # return code 0 = accepted
def test_connect_sets_auth_flag(mqtt_mod):
proto, _, _ = _make_protocol(mqtt_mod)
proto.data_received(_connect_packet())
assert proto._auth is True
def test_pingreq_returns_pingresp(mqtt_mod):
proto, _, written = _make_protocol(mqtt_mod)
proto.data_received(_connect_packet())
written.clear()
proto.data_received(b"\xc0\x00") # PINGREQ
assert b"\xd0\x00" in b"".join(written)
def test_disconnect_closes_transport(mqtt_mod):
proto, transport, _ = _make_protocol(mqtt_mod)
proto.data_received(_connect_packet())
transport.reset_mock()
proto.data_received(b"\xe0\x00") # DISCONNECT
transport.close.assert_called()
def test_publish_without_auth_closes(mqtt_mod):
proto, transport, _ = _make_protocol(mqtt_mod)
# PUBLISH without prior CONNECT
topic = b"\x00\x04test"
payload = b"hello"
remaining = len(topic) + len(payload)
proto.data_received(bytes([0x30, remaining]) + topic + payload)
transport.close.assert_called()
def test_partial_packet_waits_for_more(mqtt_mod):
proto, transport, _ = _make_protocol(mqtt_mod)
proto.data_received(b"\x10") # just the first byte
transport.close.assert_not_called()
def test_connection_lost_does_not_raise(mqtt_mod):
proto, _, _ = _make_protocol(mqtt_mod)
proto.connection_lost(None)
# ── Regression: overlong remaining-length field ───────────────────────────────
def test_5_continuation_bytes_closes(mqtt_mod):
proto, transport, _ = _make_protocol(mqtt_mod)
# 5 bytes with continuation bit set, then a final byte
# MQTT spec allows max 4 bytes — this must be rejected
data = bytes([0x30, 0x80, 0x80, 0x80, 0x80, 0x01])
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
def test_6_continuation_bytes_closes(mqtt_mod):
proto, transport, _ = _make_protocol(mqtt_mod)
data = bytes([0x30]) + bytes([0x80] * 6) + b"\x01"
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
def test_4_continuation_bytes_is_accepted(mqtt_mod):
proto, transport, _ = _make_protocol(mqtt_mod)
# 4 bytes total for remaining length = max allowed.
# remaining = 0x0FFFFFFF = 268435455 bytes — huge but spec-valid encoding.
# With no data following, it simply returns (incomplete payload) — not closed.
data = bytes([0x30, 0xff, 0xff, 0xff, 0x7f])
run_with_timeout(proto.data_received, data)
transport.close.assert_not_called()
def test_zero_remaining_publish_does_not_close(mqtt_mod):
proto, transport, _ = _make_protocol(mqtt_mod)
proto.data_received(_connect_packet())
transport.reset_mock()
# PUBLISH with remaining=0 is unusual but not a protocol violation
proto.data_received(b"\x30\x00")
transport.close.assert_not_called()
# ── Fuzz ──────────────────────────────────────────────────────────────────────
@pytest.mark.fuzz
@given(data=st.binary(min_size=0, max_size=512))
@settings(**_FUZZ_SETTINGS)
def test_fuzz_unauthenticated(data):
mod = _load_mqtt()
proto, _, _ = _make_protocol(mod)
run_with_timeout(proto.data_received, data)
@pytest.mark.fuzz
@given(data=st.binary(min_size=0, max_size=512))
@settings(**_FUZZ_SETTINGS)
def test_fuzz_after_connect(data):
mod = _load_mqtt()
proto, _, _ = _make_protocol(mod)
proto.data_received(_connect_packet())
run_with_timeout(proto.data_received, data)

View File

@@ -0,0 +1,137 @@
"""
Tests for templates/mssql/server.py
Covers the TDS pre-login / login7 happy path and regression tests for the
zero-length pkt_len infinite-loop bug that was fixed (pkt_len < 8 guard).
"""
import importlib.util
import struct
import sys
from unittest.mock import MagicMock
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
# ── Helpers ───────────────────────────────────────────────────────────────────
def _load_mssql():
for key in list(sys.modules):
if key in ("mssql_server", "decnet_logging"):
del sys.modules[key]
sys.modules["decnet_logging"] = make_fake_decnet_logging()
spec = importlib.util.spec_from_file_location("mssql_server", "templates/mssql/server.py")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def _make_protocol(mod):
proto = mod.MSSQLProtocol()
transport = MagicMock()
written: list[bytes] = []
transport.write.side_effect = written.append
transport.is_closing.return_value = False
proto.connection_made(transport)
return proto, transport, written
def _tds_header(pkt_type: int, pkt_len: int) -> bytes:
"""Build an 8-byte TDS packet header."""
return struct.pack(">BBHBBBB", pkt_type, 0x01, pkt_len, 0x00, 0x00, 0x01, 0x00)
def _prelogin_packet() -> bytes:
header = _tds_header(0x12, 8)
return header
def _login7_packet() -> bytes:
"""Minimal Login7 with 40-byte payload (username at offset 0, length 0)."""
payload = b"\x00" * 40
pkt_len = 8 + len(payload)
header = _tds_header(0x10, pkt_len)
return header + payload
@pytest.fixture
def mssql_mod():
return _load_mssql()
# ── Happy path ────────────────────────────────────────────────────────────────
def test_prelogin_response_is_tds_type4(mssql_mod):
proto, _, written = _make_protocol(mssql_mod)
proto.data_received(_prelogin_packet())
assert written, "expected a pre-login response"
assert written[0][0] == 0x04
def test_prelogin_response_length_matches_header(mssql_mod):
proto, _, written = _make_protocol(mssql_mod)
proto.data_received(_prelogin_packet())
resp = b"".join(written)
declared_len = struct.unpack(">H", resp[2:4])[0]
assert declared_len == len(resp)
def test_login7_auth_logged_and_closes(mssql_mod):
proto, transport, written = _make_protocol(mssql_mod)
proto.data_received(_prelogin_packet())
written.clear()
proto.data_received(_login7_packet())
transport.close.assert_called()
# error packet must be present
assert any(b"\xaa" in chunk for chunk in written)
def test_partial_header_waits_for_more_data(mssql_mod):
proto, transport, _ = _make_protocol(mssql_mod)
proto.data_received(b"\x12\x01")
transport.close.assert_not_called()
def test_connection_lost_does_not_raise(mssql_mod):
proto, _, _ = _make_protocol(mssql_mod)
proto.connection_lost(None)
# ── Regression: zero / small pkt_len ─────────────────────────────────────────
def test_zero_pkt_len_closes(mssql_mod):
proto, transport, _ = _make_protocol(mssql_mod)
# pkt_len = 0x0000 at bytes [2:4]
data = b"\x12\x01\x00\x00\x00\x00\x01\x00"
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
def test_pkt_len_7_closes(mssql_mod):
proto, transport, _ = _make_protocol(mssql_mod)
# pkt_len = 7 (< 8 minimum)
data = _tds_header(0x12, 7) + b"\x00"
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
def test_pkt_len_1_closes(mssql_mod):
proto, transport, _ = _make_protocol(mssql_mod)
data = _tds_header(0x12, 1) + b"\x00" * 7
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
# ── Fuzz ──────────────────────────────────────────────────────────────────────
@pytest.mark.fuzz
@given(data=st.binary(min_size=0, max_size=512))
@settings(**_FUZZ_SETTINGS)
def test_fuzz_arbitrary_bytes(data):
mod = _load_mssql()
proto, _, _ = _make_protocol(mod)
run_with_timeout(proto.data_received, data)

View File

@@ -0,0 +1,153 @@
"""
Tests for templates/mysql/server.py
Covers the MySQL handshake happy path and regression tests for oversized
length fields that could cause huge buffer allocations.
"""
import importlib.util
import struct
import sys
from unittest.mock import MagicMock, patch
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
# ── Helpers ───────────────────────────────────────────────────────────────────
def _load_mysql():
for key in list(sys.modules):
if key in ("mysql_server", "decnet_logging"):
del sys.modules[key]
sys.modules["decnet_logging"] = make_fake_decnet_logging()
spec = importlib.util.spec_from_file_location("mysql_server", "templates/mysql/server.py")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def _make_protocol(mod):
proto = mod.MySQLProtocol()
transport = MagicMock()
written: list[bytes] = []
transport.write.side_effect = written.append
proto.connection_made(transport)
written.clear() # clear the greeting sent on connect
return proto, transport, written
def _make_packet(payload: bytes, seq: int = 1) -> bytes:
length = len(payload)
return struct.pack("<I", length)[:3] + bytes([seq]) + payload
def _login_packet(username: str = "root") -> bytes:
"""Minimal MySQL client login packet."""
caps = struct.pack("<I", 0x000FA685)
max_pkt = struct.pack("<I", 16777216)
charset = b"\x21"
reserved = b"\x00" * 23
uname = username.encode() + b"\x00"
payload = caps + max_pkt + charset + reserved + uname
return _make_packet(payload, seq=1)
@pytest.fixture
def mysql_mod():
return _load_mysql()
# ── Happy path ────────────────────────────────────────────────────────────────
def test_connection_sends_greeting(mysql_mod):
proto = mysql_mod.MySQLProtocol()
transport = MagicMock()
written: list[bytes] = []
transport.write.side_effect = written.append
proto.connection_made(transport)
greeting = b"".join(written)
assert greeting[4] == 0x0a # protocol v10
assert b"mysql_native_password" in greeting
def test_login_packet_triggers_close(mysql_mod):
proto, transport, _ = _make_protocol(mysql_mod)
proto.data_received(_login_packet())
transport.close.assert_called()
def test_login_packet_returns_access_denied(mysql_mod):
proto, _, written = _make_protocol(mysql_mod)
proto.data_received(_login_packet())
resp = b"".join(written)
assert b"\xff" in resp # error packet marker
def test_login_logs_username():
mod = _load_mysql()
log_mock = sys.modules["decnet_logging"]
proto, _, _ = _make_protocol(mod)
proto.data_received(_login_packet(username="hacker"))
calls_str = str(log_mock.syslog_line.call_args_list)
assert "hacker" in calls_str
def test_empty_payload_packet_does_not_crash(mysql_mod):
proto, transport, _ = _make_protocol(mysql_mod)
proto.data_received(_make_packet(b"", seq=1))
# Empty payload is silently skipped — no crash, no close
transport.close.assert_not_called()
def test_partial_header_waits_for_more(mysql_mod):
proto, transport, _ = _make_protocol(mysql_mod)
proto.data_received(b"\x00\x00\x00") # only 3 bytes — not enough
transport.close.assert_not_called()
def test_connection_lost_does_not_raise(mysql_mod):
proto, _, _ = _make_protocol(mysql_mod)
proto.connection_lost(None)
# ── Regression: oversized length field ───────────────────────────────────────
def test_length_over_1mb_closes(mysql_mod):
proto, transport, _ = _make_protocol(mysql_mod)
# 1MB + 1 in 3-byte LE: 0x100001 → b'\x01\x00\x10'
over_1mb = struct.pack("<I", 1024 * 1024 + 1)[:3]
data = over_1mb + b"\x01" # seq=1
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
def test_max_3byte_length_closes(mysql_mod):
proto, transport, _ = _make_protocol(mysql_mod)
# 0xFFFFFF = 16,777,215 — max representable in 3 bytes, clearly > 1MB cap
data = b"\xff\xff\xff\x01"
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
def test_length_just_over_1mb_closes(mysql_mod):
proto, transport, _ = _make_protocol(mysql_mod)
# 1MB + 1 byte — just over the cap
just_over = struct.pack("<I", 1024 * 1024 + 1)[:3]
data = just_over + b"\x01"
run_with_timeout(proto.data_received, data)
transport.close.assert_called()
# ── Fuzz ──────────────────────────────────────────────────────────────────────
@pytest.mark.fuzz
@given(data=st.binary(min_size=0, max_size=512))
@settings(**_FUZZ_SETTINGS)
def test_fuzz_arbitrary_bytes(data):
mod = _load_mysql()
proto, _, _ = _make_protocol(mod)
run_with_timeout(proto.data_received, data)

View File

@@ -0,0 +1,189 @@
"""
Tests for templates/postgres/server.py
Covers the PostgreSQL startup / MD5-auth handshake happy path and regression
tests for zero/tiny/huge msg_len in both the startup and auth states.
"""
import importlib.util
import struct
import sys
from unittest.mock import MagicMock, patch
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
# ── Helpers ───────────────────────────────────────────────────────────────────
def _load_postgres():
for key in list(sys.modules):
if key in ("postgres_server", "decnet_logging"):
del sys.modules[key]
sys.modules["decnet_logging"] = make_fake_decnet_logging()
spec = importlib.util.spec_from_file_location("postgres_server", "templates/postgres/server.py")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def _make_protocol(mod):
proto = mod.PostgresProtocol()
transport = MagicMock()
written: list[bytes] = []
transport.write.side_effect = written.append
proto.connection_made(transport)
return proto, transport, written
def _startup_msg(user: str = "postgres", database: str = "postgres") -> bytes:
"""Build a valid PostgreSQL startup message."""
params = f"user\x00{user}\x00database\x00{database}\x00\x00".encode()
protocol = struct.pack(">I", 0x00030000)
body = protocol + params
msg_len = struct.pack(">I", 4 + len(body))
return msg_len + body
def _ssl_request() -> bytes:
return struct.pack(">II", 8, 80877103)
def _password_msg(password: str = "wrongpass") -> bytes:
pw = password.encode() + b"\x00"
return b"p" + struct.pack(">I", 4 + len(pw)) + pw
@pytest.fixture
def postgres_mod():
return _load_postgres()
# ── Happy path ────────────────────────────────────────────────────────────────
def test_ssl_request_returns_N(postgres_mod):
proto, _, written = _make_protocol(postgres_mod)
proto.data_received(_ssl_request())
assert b"N" in b"".join(written)
def test_startup_sends_auth_challenge(postgres_mod):
proto, _, written = _make_protocol(postgres_mod)
proto.data_received(_startup_msg())
resp = b"".join(written)
# AuthenticationMD5Password = 'R' + len(12) + type(5) + salt(4)
assert resp[0:1] == b"R"
def test_startup_logs_username():
mod = _load_postgres()
log_mock = sys.modules["decnet_logging"]
proto, _, _ = _make_protocol(mod)
proto.data_received(_startup_msg(user="attacker"))
log_mock.syslog_line.assert_called()
calls_str = str(log_mock.syslog_line.call_args_list)
assert "attacker" in calls_str
def test_state_becomes_auth_after_startup(postgres_mod):
proto, _, _ = _make_protocol(postgres_mod)
proto.data_received(_startup_msg())
assert proto._state == "auth"
def test_password_triggers_close(postgres_mod):
proto, transport, _ = _make_protocol(postgres_mod)
proto.data_received(_startup_msg())
transport.reset_mock()
proto.data_received(_password_msg())
transport.close.assert_called()
def test_partial_startup_waits(postgres_mod):
proto, transport, _ = _make_protocol(postgres_mod)
proto.data_received(b"\x00\x00\x00") # only 3 bytes — not enough for msg_len
transport.close.assert_not_called()
assert proto._state == "startup"
def test_connection_lost_does_not_raise(postgres_mod):
proto, _, _ = _make_protocol(postgres_mod)
proto.connection_lost(None)
# ── Regression: startup state bad msg_len ────────────────────────────────────
def test_zero_msg_len_startup_closes(postgres_mod):
proto, transport, _ = _make_protocol(postgres_mod)
run_with_timeout(proto.data_received, b"\x00\x00\x00\x00")
transport.close.assert_called()
def test_msg_len_4_startup_closes(postgres_mod):
proto, transport, _ = _make_protocol(postgres_mod)
# msg_len=4 means zero-byte body — too small for startup (needs protocol version)
run_with_timeout(proto.data_received, struct.pack(">I", 4) + b"\x00" * 4)
transport.close.assert_called()
def test_msg_len_7_startup_closes(postgres_mod):
proto, transport, _ = _make_protocol(postgres_mod)
run_with_timeout(proto.data_received, struct.pack(">I", 7) + b"\x00" * 7)
transport.close.assert_called()
def test_huge_msg_len_startup_closes(postgres_mod):
proto, transport, _ = _make_protocol(postgres_mod)
run_with_timeout(proto.data_received, struct.pack(">I", 0x7FFFFFFF) + b"\x00" * 4)
transport.close.assert_called()
# ── Regression: auth state bad msg_len ───────────────────────────────────────
def test_zero_msg_len_auth_closes(postgres_mod):
proto, transport, _ = _make_protocol(postgres_mod)
proto.data_received(_startup_msg())
transport.reset_mock()
# 'p' + msg_len=0
run_with_timeout(proto.data_received, b"p" + struct.pack(">I", 0))
transport.close.assert_called()
def test_msg_len_1_auth_closes(postgres_mod):
proto, transport, _ = _make_protocol(postgres_mod)
proto.data_received(_startup_msg())
transport.reset_mock()
run_with_timeout(proto.data_received, b"p" + struct.pack(">I", 1) + b"\x00" * 5)
transport.close.assert_called()
def test_huge_msg_len_auth_closes(postgres_mod):
proto, transport, _ = _make_protocol(postgres_mod)
proto.data_received(_startup_msg())
transport.reset_mock()
run_with_timeout(proto.data_received, b"p" + struct.pack(">I", 0x7FFFFFFF) + b"\x00" * 5)
transport.close.assert_called()
# ── Fuzz ──────────────────────────────────────────────────────────────────────
@pytest.mark.fuzz
@given(data=st.binary(min_size=0, max_size=512))
@settings(**_FUZZ_SETTINGS)
def test_fuzz_startup_state(data):
mod = _load_postgres()
proto, _, _ = _make_protocol(mod)
run_with_timeout(proto.data_received, data)
@pytest.mark.fuzz
@given(data=st.binary(min_size=0, max_size=512))
@settings(**_FUZZ_SETTINGS)
def test_fuzz_auth_state(data):
mod = _load_postgres()
proto, _, _ = _make_protocol(mod)
proto.data_received(_startup_msg())
run_with_timeout(proto.data_received, data)