fix(protocols): guard against zero/malformed length fields in binary protocol parsers
MongoDB had the same infinite-loop bug as MSSQL (msg_len=0 → buffer never shrinks in while loop). Postgres, MySQL, and MQTT had related length-field issues (stuck state, resource exhaustion, overlong remaining-length). Also fixes an existing MongoDB _op_reply struct.pack format bug (extra 'q' specifier caused struct.error on any OP_QUERY response). Adds 53 regression + protocol boundary tests across MSSQL, MongoDB, Postgres, MySQL, and MQTT, including a _run_with_timeout threading harness to catch infinite loops and @pytest.mark.fuzz hypothesis tests for each.
This commit is contained in:
@@ -35,13 +35,13 @@ def _op_reply(request_id: int, doc: bytes) -> bytes:
|
|||||||
# OP_REPLY header: total_len(4), req_id(4), response_to(4), opcode(4)=1,
|
# OP_REPLY header: total_len(4), req_id(4), response_to(4), opcode(4)=1,
|
||||||
# flags(4), cursor_id(8), starting_from(4), number_returned(4), docs
|
# flags(4), cursor_id(8), starting_from(4), number_returned(4), docs
|
||||||
header = struct.pack(
|
header = struct.pack(
|
||||||
"<iiiiiqqii",
|
"<iiiiiqii",
|
||||||
16 + 20 + len(doc), # total length
|
16 + 20 + len(doc), # total length
|
||||||
0, # request id
|
0, # request id
|
||||||
request_id, # response to
|
request_id, # response to
|
||||||
1, # OP_REPLY
|
1, # OP_REPLY
|
||||||
0, # flags
|
0, # flags
|
||||||
0, # cursor id
|
0, # cursor id (int64)
|
||||||
0, # starting from
|
0, # starting from
|
||||||
1, # number returned
|
1, # number returned
|
||||||
)
|
)
|
||||||
@@ -81,6 +81,10 @@ class MongoDBProtocol(asyncio.Protocol):
|
|||||||
self._buf += data
|
self._buf += data
|
||||||
while len(self._buf) >= 16:
|
while len(self._buf) >= 16:
|
||||||
msg_len = struct.unpack("<I", self._buf[:4])[0]
|
msg_len = struct.unpack("<I", self._buf[:4])[0]
|
||||||
|
if msg_len < 16 or msg_len > 48 * 1024 * 1024:
|
||||||
|
self._transport.close()
|
||||||
|
self._buf = b""
|
||||||
|
return
|
||||||
if len(self._buf) < msg_len:
|
if len(self._buf) < msg_len:
|
||||||
break
|
break
|
||||||
msg = self._buf[:msg_len]
|
msg = self._buf[:msg_len]
|
||||||
|
|||||||
@@ -191,6 +191,10 @@ class MQTTProtocol(asyncio.Protocol):
|
|||||||
remaining = 0
|
remaining = 0
|
||||||
multiplier = 1
|
multiplier = 1
|
||||||
while pos < len(self._buf):
|
while pos < len(self._buf):
|
||||||
|
if pos > 4: # MQTT spec: max 4 bytes for remaining length
|
||||||
|
self._transport.close()
|
||||||
|
self._buf = b""
|
||||||
|
return
|
||||||
byte = self._buf[pos]
|
byte = self._buf[pos]
|
||||||
remaining += (byte & 0x7f) * multiplier
|
remaining += (byte & 0x7f) * multiplier
|
||||||
multiplier *= 128
|
multiplier *= 128
|
||||||
|
|||||||
@@ -67,6 +67,10 @@ class MySQLProtocol(asyncio.Protocol):
|
|||||||
# MySQL packets: 3-byte length + 1-byte seq + payload
|
# MySQL packets: 3-byte length + 1-byte seq + payload
|
||||||
while len(self._buf) >= 4:
|
while len(self._buf) >= 4:
|
||||||
length = struct.unpack("<I", self._buf[:3] + b"\x00")[0]
|
length = struct.unpack("<I", self._buf[:3] + b"\x00")[0]
|
||||||
|
if length > 1024 * 1024:
|
||||||
|
self._transport.close()
|
||||||
|
self._buf = b""
|
||||||
|
return
|
||||||
if len(self._buf) < 4 + length:
|
if len(self._buf) < 4 + length:
|
||||||
break
|
break
|
||||||
payload = self._buf[4:4 + length]
|
payload = self._buf[4:4 + length]
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ class PostgresProtocol(asyncio.Protocol):
|
|||||||
if len(self._buf) < 4:
|
if len(self._buf) < 4:
|
||||||
return
|
return
|
||||||
msg_len = struct.unpack(">I", self._buf[:4])[0]
|
msg_len = struct.unpack(">I", self._buf[:4])[0]
|
||||||
|
if msg_len < 8 or msg_len > 10_000:
|
||||||
|
self._transport.close()
|
||||||
|
self._buf = b""
|
||||||
|
return
|
||||||
if len(self._buf) < msg_len:
|
if len(self._buf) < msg_len:
|
||||||
return
|
return
|
||||||
msg = self._buf[:msg_len]
|
msg = self._buf[:msg_len]
|
||||||
@@ -59,6 +63,10 @@ class PostgresProtocol(asyncio.Protocol):
|
|||||||
return
|
return
|
||||||
msg_type = chr(self._buf[0])
|
msg_type = chr(self._buf[0])
|
||||||
msg_len = struct.unpack(">I", self._buf[1:5])[0]
|
msg_len = struct.unpack(">I", self._buf[1:5])[0]
|
||||||
|
if msg_len < 4 or msg_len > 10_000:
|
||||||
|
self._transport.close()
|
||||||
|
self._buf = b""
|
||||||
|
return
|
||||||
if len(self._buf) < msg_len + 1:
|
if len(self._buf) < msg_len + 1:
|
||||||
return
|
return
|
||||||
payload = self._buf[5:msg_len + 1]
|
payload = self._buf[5:msg_len + 1]
|
||||||
|
|||||||
47
tests/service_testing/conftest.py
Normal file
47
tests/service_testing/conftest.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""
|
||||||
|
Shared helpers for binary-protocol service tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from types import ModuleType
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from hypothesis import HealthCheck
|
||||||
|
|
||||||
|
|
||||||
|
_FUZZ_SETTINGS = dict(
|
||||||
|
max_examples=int(os.environ.get("HYPOTHESIS_MAX_EXAMPLES", "200")),
|
||||||
|
deadline=2000,
|
||||||
|
suppress_health_check=[HealthCheck.function_scoped_fixture],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_fake_decnet_logging() -> ModuleType:
|
||||||
|
mod = ModuleType("decnet_logging")
|
||||||
|
mod.syslog_line = MagicMock(return_value="")
|
||||||
|
mod.write_syslog_file = MagicMock()
|
||||||
|
mod.forward_syslog = MagicMock()
|
||||||
|
mod.SEVERITY_WARNING = 4
|
||||||
|
mod.SEVERITY_INFO = 6
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
def run_with_timeout(fn, *args, timeout: float = 2.0) -> None:
|
||||||
|
"""Run fn(*args) in a daemon thread. pytest.fail if it doesn't return in time."""
|
||||||
|
exc_box: list[BaseException] = []
|
||||||
|
|
||||||
|
def _target():
|
||||||
|
try:
|
||||||
|
fn(*args)
|
||||||
|
except Exception as e:
|
||||||
|
exc_box.append(e)
|
||||||
|
|
||||||
|
t = threading.Thread(target=_target, daemon=True)
|
||||||
|
t.start()
|
||||||
|
t.join(timeout)
|
||||||
|
if t.is_alive():
|
||||||
|
pytest.fail(f"data_received hung for >{timeout}s — likely infinite loop")
|
||||||
|
if exc_box:
|
||||||
|
raise exc_box[0]
|
||||||
161
tests/service_testing/test_mongodb.py
Normal file
161
tests/service_testing/test_mongodb.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""
|
||||||
|
Tests for templates/mongodb/server.py
|
||||||
|
|
||||||
|
Covers the MongoDB wire-protocol (OP_MSG / OP_QUERY) happy path and regression
|
||||||
|
tests for the zero-length msg_len infinite-loop bug and oversized msg_len.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from hypothesis import given, settings
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _load_mongodb():
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key in ("mongodb_server", "decnet_logging"):
|
||||||
|
del sys.modules[key]
|
||||||
|
sys.modules["decnet_logging"] = make_fake_decnet_logging()
|
||||||
|
spec = importlib.util.spec_from_file_location("mongodb_server", "templates/mongodb/server.py")
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
def _make_protocol(mod):
|
||||||
|
proto = mod.MongoDBProtocol()
|
||||||
|
transport = MagicMock()
|
||||||
|
written: list[bytes] = []
|
||||||
|
transport.write.side_effect = written.append
|
||||||
|
proto.connection_made(transport)
|
||||||
|
return proto, transport, written
|
||||||
|
|
||||||
|
|
||||||
|
def _minimal_bson() -> bytes:
|
||||||
|
return b"\x05\x00\x00\x00\x00" # empty document
|
||||||
|
|
||||||
|
|
||||||
|
def _op_msg_packet(request_id: int = 1) -> bytes:
|
||||||
|
"""Build a valid OP_MSG with an empty BSON body."""
|
||||||
|
flag_bits = struct.pack("<I", 0)
|
||||||
|
section = b"\x00" + _minimal_bson()
|
||||||
|
body = flag_bits + section
|
||||||
|
total = 16 + len(body)
|
||||||
|
header = struct.pack("<iiii", total, request_id, 0, 2013)
|
||||||
|
return header + body
|
||||||
|
|
||||||
|
|
||||||
|
def _op_query_packet(request_id: int = 2) -> bytes:
|
||||||
|
"""Build a minimal OP_QUERY."""
|
||||||
|
flags = struct.pack("<I", 0)
|
||||||
|
coll = b"admin.$cmd\x00"
|
||||||
|
skip = struct.pack("<I", 0)
|
||||||
|
ret = struct.pack("<I", 1)
|
||||||
|
query = _minimal_bson()
|
||||||
|
body = flags + coll + skip + ret + query
|
||||||
|
total = 16 + len(body)
|
||||||
|
header = struct.pack("<iiii", total, request_id, 0, 2004)
|
||||||
|
return header + body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mongodb_mod():
|
||||||
|
return _load_mongodb()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Happy path ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_op_msg_returns_response(mongodb_mod):
|
||||||
|
proto, _, written = _make_protocol(mongodb_mod)
|
||||||
|
proto.data_received(_op_msg_packet())
|
||||||
|
assert written, "expected a response to OP_MSG"
|
||||||
|
|
||||||
|
|
||||||
|
def test_op_msg_response_opcode_is_2013(mongodb_mod):
|
||||||
|
proto, _, written = _make_protocol(mongodb_mod)
|
||||||
|
proto.data_received(_op_msg_packet())
|
||||||
|
resp = b"".join(written)
|
||||||
|
assert len(resp) >= 16
|
||||||
|
opcode = struct.unpack("<i", resp[12:16])[0]
|
||||||
|
assert opcode == 2013
|
||||||
|
|
||||||
|
|
||||||
|
def test_op_query_returns_op_reply(mongodb_mod):
|
||||||
|
proto, _, written = _make_protocol(mongodb_mod)
|
||||||
|
proto.data_received(_op_query_packet())
|
||||||
|
resp = b"".join(written)
|
||||||
|
assert len(resp) >= 16
|
||||||
|
opcode = struct.unpack("<i", resp[12:16])[0]
|
||||||
|
assert opcode == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_header_waits_for_more_data(mongodb_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mongodb_mod)
|
||||||
|
proto.data_received(b"\x1a\x00\x00\x00") # only 4 bytes (< 16)
|
||||||
|
transport.close.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_consecutive_messages(mongodb_mod):
|
||||||
|
proto, _, written = _make_protocol(mongodb_mod)
|
||||||
|
two = _op_msg_packet(1) + _op_msg_packet(2)
|
||||||
|
proto.data_received(two)
|
||||||
|
assert len(written) >= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_lost_does_not_raise(mongodb_mod):
|
||||||
|
proto, _, _ = _make_protocol(mongodb_mod)
|
||||||
|
proto.connection_lost(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Regression: malformed msg_len ────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_zero_msg_len_closes(mongodb_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mongodb_mod)
|
||||||
|
# msg_len = 0 at bytes [0:4] LE — buffer has 16 bytes so outer while triggers
|
||||||
|
data = b"\x00\x00\x00\x00" + b"\x00" * 12
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_msg_len_15_closes(mongodb_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mongodb_mod)
|
||||||
|
# msg_len = 15 (below 16-byte wire-protocol minimum)
|
||||||
|
data = struct.pack("<I", 15) + b"\x00" * 12
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_msg_len_over_48mb_closes(mongodb_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mongodb_mod)
|
||||||
|
# msg_len = 48MB + 1
|
||||||
|
big = 48 * 1024 * 1024 + 1
|
||||||
|
data = struct.pack("<I", big) + b"\x00" * 12
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_msg_len_exactly_48mb_plus1_closes(mongodb_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mongodb_mod)
|
||||||
|
# cap is strictly > 48MB, so 48MB+1 must close
|
||||||
|
data = struct.pack("<I", 48 * 1024 * 1024 + 1) + b"\x00" * 12
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fuzz ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.fuzz
|
||||||
|
@given(data=st.binary(min_size=0, max_size=512))
|
||||||
|
@settings(**_FUZZ_SETTINGS)
|
||||||
|
def test_fuzz_arbitrary_bytes(data):
|
||||||
|
mod = _load_mongodb()
|
||||||
|
proto, _, _ = _make_protocol(mod)
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
185
tests/service_testing/test_mqtt_fuzz.py
Normal file
185
tests/service_testing/test_mqtt_fuzz.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""
|
||||||
|
Tests for templates/mqtt/server.py — protocol boundary and fuzz cases.
|
||||||
|
|
||||||
|
Focuses on the variable-length remaining-length field (MQTT spec: max 4 bytes).
|
||||||
|
A 5th continuation byte used to cause the server to get stuck waiting for a
|
||||||
|
payload it could never receive (remaining = hundreds of MB).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from hypothesis import given, settings
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _load_mqtt():
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key in ("mqtt_server", "decnet_logging"):
|
||||||
|
del sys.modules[key]
|
||||||
|
sys.modules["decnet_logging"] = make_fake_decnet_logging()
|
||||||
|
spec = importlib.util.spec_from_file_location("mqtt_server", "templates/mqtt/server.py")
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
with patch.dict("os.environ", {"MQTT_ACCEPT_ALL": "1", "MQTT_PERSONA": "water_plant"}, clear=False):
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
def _make_protocol(mod):
|
||||||
|
proto = mod.MQTTProtocol()
|
||||||
|
transport = MagicMock()
|
||||||
|
written: list[bytes] = []
|
||||||
|
transport.write.side_effect = written.append
|
||||||
|
proto.connection_made(transport)
|
||||||
|
return proto, transport, written
|
||||||
|
|
||||||
|
|
||||||
|
def _connect_packet(client_id: str = "test-client") -> bytes:
|
||||||
|
"""Build a minimal MQTT CONNECT packet."""
|
||||||
|
proto_name = b"\x00\x04MQTT"
|
||||||
|
proto_level = b"\x04" # 3.1.1
|
||||||
|
flags = b"\x02" # clean session
|
||||||
|
keepalive = b"\x00\x3c"
|
||||||
|
cid = client_id.encode()
|
||||||
|
cid_field = struct.pack(">H", len(cid)) + cid
|
||||||
|
payload = proto_name + proto_level + flags + keepalive + cid_field
|
||||||
|
remaining = len(payload)
|
||||||
|
# single-byte remaining length (works for short payloads)
|
||||||
|
return bytes([0x10, remaining]) + payload
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_remaining(value: int) -> bytes:
|
||||||
|
"""Encode a value using MQTT variable-length encoding."""
|
||||||
|
result = []
|
||||||
|
while True:
|
||||||
|
encoded = value % 128
|
||||||
|
value //= 128
|
||||||
|
if value > 0:
|
||||||
|
encoded |= 128
|
||||||
|
result.append(encoded)
|
||||||
|
if value == 0:
|
||||||
|
break
|
||||||
|
return bytes(result)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mqtt_mod():
|
||||||
|
return _load_mqtt()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Happy path ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_connect_returns_connack_accepted(mqtt_mod):
|
||||||
|
proto, _, written = _make_protocol(mqtt_mod)
|
||||||
|
proto.data_received(_connect_packet())
|
||||||
|
resp = b"".join(written)
|
||||||
|
assert resp[:2] == b"\x20\x02" # CONNACK
|
||||||
|
assert resp[3:4] == b"\x00" # return code 0 = accepted
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_sets_auth_flag(mqtt_mod):
|
||||||
|
proto, _, _ = _make_protocol(mqtt_mod)
|
||||||
|
proto.data_received(_connect_packet())
|
||||||
|
assert proto._auth is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_pingreq_returns_pingresp(mqtt_mod):
|
||||||
|
proto, _, written = _make_protocol(mqtt_mod)
|
||||||
|
proto.data_received(_connect_packet())
|
||||||
|
written.clear()
|
||||||
|
proto.data_received(b"\xc0\x00") # PINGREQ
|
||||||
|
assert b"\xd0\x00" in b"".join(written)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_closes_transport(mqtt_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mqtt_mod)
|
||||||
|
proto.data_received(_connect_packet())
|
||||||
|
transport.reset_mock()
|
||||||
|
proto.data_received(b"\xe0\x00") # DISCONNECT
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_publish_without_auth_closes(mqtt_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mqtt_mod)
|
||||||
|
# PUBLISH without prior CONNECT
|
||||||
|
topic = b"\x00\x04test"
|
||||||
|
payload = b"hello"
|
||||||
|
remaining = len(topic) + len(payload)
|
||||||
|
proto.data_received(bytes([0x30, remaining]) + topic + payload)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_packet_waits_for_more(mqtt_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mqtt_mod)
|
||||||
|
proto.data_received(b"\x10") # just the first byte
|
||||||
|
transport.close.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_lost_does_not_raise(mqtt_mod):
|
||||||
|
proto, _, _ = _make_protocol(mqtt_mod)
|
||||||
|
proto.connection_lost(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Regression: overlong remaining-length field ───────────────────────────────
|
||||||
|
|
||||||
|
def test_5_continuation_bytes_closes(mqtt_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mqtt_mod)
|
||||||
|
# 5 bytes with continuation bit set, then a final byte
|
||||||
|
# MQTT spec allows max 4 bytes — this must be rejected
|
||||||
|
data = bytes([0x30, 0x80, 0x80, 0x80, 0x80, 0x01])
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_6_continuation_bytes_closes(mqtt_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mqtt_mod)
|
||||||
|
data = bytes([0x30]) + bytes([0x80] * 6) + b"\x01"
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_4_continuation_bytes_is_accepted(mqtt_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mqtt_mod)
|
||||||
|
# 4 bytes total for remaining length = max allowed.
|
||||||
|
# remaining = 0x0FFFFFFF = 268435455 bytes — huge but spec-valid encoding.
|
||||||
|
# With no data following, it simply returns (incomplete payload) — not closed.
|
||||||
|
data = bytes([0x30, 0xff, 0xff, 0xff, 0x7f])
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_remaining_publish_does_not_close(mqtt_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mqtt_mod)
|
||||||
|
proto.data_received(_connect_packet())
|
||||||
|
transport.reset_mock()
|
||||||
|
# PUBLISH with remaining=0 is unusual but not a protocol violation
|
||||||
|
proto.data_received(b"\x30\x00")
|
||||||
|
transport.close.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fuzz ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.fuzz
|
||||||
|
@given(data=st.binary(min_size=0, max_size=512))
|
||||||
|
@settings(**_FUZZ_SETTINGS)
|
||||||
|
def test_fuzz_unauthenticated(data):
|
||||||
|
mod = _load_mqtt()
|
||||||
|
proto, _, _ = _make_protocol(mod)
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.fuzz
|
||||||
|
@given(data=st.binary(min_size=0, max_size=512))
|
||||||
|
@settings(**_FUZZ_SETTINGS)
|
||||||
|
def test_fuzz_after_connect(data):
|
||||||
|
mod = _load_mqtt()
|
||||||
|
proto, _, _ = _make_protocol(mod)
|
||||||
|
proto.data_received(_connect_packet())
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
137
tests/service_testing/test_mssql.py
Normal file
137
tests/service_testing/test_mssql.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""
|
||||||
|
Tests for templates/mssql/server.py
|
||||||
|
|
||||||
|
Covers the TDS pre-login / login7 happy path and regression tests for the
|
||||||
|
zero-length pkt_len infinite-loop bug that was fixed (pkt_len < 8 guard).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from hypothesis import given, settings
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _load_mssql():
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key in ("mssql_server", "decnet_logging"):
|
||||||
|
del sys.modules[key]
|
||||||
|
sys.modules["decnet_logging"] = make_fake_decnet_logging()
|
||||||
|
spec = importlib.util.spec_from_file_location("mssql_server", "templates/mssql/server.py")
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
def _make_protocol(mod):
|
||||||
|
proto = mod.MSSQLProtocol()
|
||||||
|
transport = MagicMock()
|
||||||
|
written: list[bytes] = []
|
||||||
|
transport.write.side_effect = written.append
|
||||||
|
transport.is_closing.return_value = False
|
||||||
|
proto.connection_made(transport)
|
||||||
|
return proto, transport, written
|
||||||
|
|
||||||
|
|
||||||
|
def _tds_header(pkt_type: int, pkt_len: int) -> bytes:
|
||||||
|
"""Build an 8-byte TDS packet header."""
|
||||||
|
return struct.pack(">BBHBBBB", pkt_type, 0x01, pkt_len, 0x00, 0x00, 0x01, 0x00)
|
||||||
|
|
||||||
|
|
||||||
|
def _prelogin_packet() -> bytes:
|
||||||
|
header = _tds_header(0x12, 8)
|
||||||
|
return header
|
||||||
|
|
||||||
|
|
||||||
|
def _login7_packet() -> bytes:
|
||||||
|
"""Minimal Login7 with 40-byte payload (username at offset 0, length 0)."""
|
||||||
|
payload = b"\x00" * 40
|
||||||
|
pkt_len = 8 + len(payload)
|
||||||
|
header = _tds_header(0x10, pkt_len)
|
||||||
|
return header + payload
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mssql_mod():
|
||||||
|
return _load_mssql()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Happy path ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_prelogin_response_is_tds_type4(mssql_mod):
|
||||||
|
proto, _, written = _make_protocol(mssql_mod)
|
||||||
|
proto.data_received(_prelogin_packet())
|
||||||
|
assert written, "expected a pre-login response"
|
||||||
|
assert written[0][0] == 0x04
|
||||||
|
|
||||||
|
|
||||||
|
def test_prelogin_response_length_matches_header(mssql_mod):
|
||||||
|
proto, _, written = _make_protocol(mssql_mod)
|
||||||
|
proto.data_received(_prelogin_packet())
|
||||||
|
resp = b"".join(written)
|
||||||
|
declared_len = struct.unpack(">H", resp[2:4])[0]
|
||||||
|
assert declared_len == len(resp)
|
||||||
|
|
||||||
|
|
||||||
|
def test_login7_auth_logged_and_closes(mssql_mod):
|
||||||
|
proto, transport, written = _make_protocol(mssql_mod)
|
||||||
|
proto.data_received(_prelogin_packet())
|
||||||
|
written.clear()
|
||||||
|
proto.data_received(_login7_packet())
|
||||||
|
transport.close.assert_called()
|
||||||
|
# error packet must be present
|
||||||
|
assert any(b"\xaa" in chunk for chunk in written)
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_header_waits_for_more_data(mssql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mssql_mod)
|
||||||
|
proto.data_received(b"\x12\x01")
|
||||||
|
transport.close.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_lost_does_not_raise(mssql_mod):
|
||||||
|
proto, _, _ = _make_protocol(mssql_mod)
|
||||||
|
proto.connection_lost(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Regression: zero / small pkt_len ─────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_zero_pkt_len_closes(mssql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mssql_mod)
|
||||||
|
# pkt_len = 0x0000 at bytes [2:4]
|
||||||
|
data = b"\x12\x01\x00\x00\x00\x00\x01\x00"
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_pkt_len_7_closes(mssql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mssql_mod)
|
||||||
|
# pkt_len = 7 (< 8 minimum)
|
||||||
|
data = _tds_header(0x12, 7) + b"\x00"
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_pkt_len_1_closes(mssql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mssql_mod)
|
||||||
|
data = _tds_header(0x12, 1) + b"\x00" * 7
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fuzz ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.fuzz
|
||||||
|
@given(data=st.binary(min_size=0, max_size=512))
|
||||||
|
@settings(**_FUZZ_SETTINGS)
|
||||||
|
def test_fuzz_arbitrary_bytes(data):
|
||||||
|
mod = _load_mssql()
|
||||||
|
proto, _, _ = _make_protocol(mod)
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
153
tests/service_testing/test_mysql.py
Normal file
153
tests/service_testing/test_mysql.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""
|
||||||
|
Tests for templates/mysql/server.py
|
||||||
|
|
||||||
|
Covers the MySQL handshake happy path and regression tests for oversized
|
||||||
|
length fields that could cause huge buffer allocations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from hypothesis import given, settings
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _load_mysql():
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key in ("mysql_server", "decnet_logging"):
|
||||||
|
del sys.modules[key]
|
||||||
|
sys.modules["decnet_logging"] = make_fake_decnet_logging()
|
||||||
|
spec = importlib.util.spec_from_file_location("mysql_server", "templates/mysql/server.py")
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
def _make_protocol(mod):
|
||||||
|
proto = mod.MySQLProtocol()
|
||||||
|
transport = MagicMock()
|
||||||
|
written: list[bytes] = []
|
||||||
|
transport.write.side_effect = written.append
|
||||||
|
proto.connection_made(transport)
|
||||||
|
written.clear() # clear the greeting sent on connect
|
||||||
|
return proto, transport, written
|
||||||
|
|
||||||
|
|
||||||
|
def _make_packet(payload: bytes, seq: int = 1) -> bytes:
|
||||||
|
length = len(payload)
|
||||||
|
return struct.pack("<I", length)[:3] + bytes([seq]) + payload
|
||||||
|
|
||||||
|
|
||||||
|
def _login_packet(username: str = "root") -> bytes:
|
||||||
|
"""Minimal MySQL client login packet."""
|
||||||
|
caps = struct.pack("<I", 0x000FA685)
|
||||||
|
max_pkt = struct.pack("<I", 16777216)
|
||||||
|
charset = b"\x21"
|
||||||
|
reserved = b"\x00" * 23
|
||||||
|
uname = username.encode() + b"\x00"
|
||||||
|
payload = caps + max_pkt + charset + reserved + uname
|
||||||
|
return _make_packet(payload, seq=1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mysql_mod():
|
||||||
|
return _load_mysql()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Happy path ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_connection_sends_greeting(mysql_mod):
|
||||||
|
proto = mysql_mod.MySQLProtocol()
|
||||||
|
transport = MagicMock()
|
||||||
|
written: list[bytes] = []
|
||||||
|
transport.write.side_effect = written.append
|
||||||
|
proto.connection_made(transport)
|
||||||
|
greeting = b"".join(written)
|
||||||
|
assert greeting[4] == 0x0a # protocol v10
|
||||||
|
assert b"mysql_native_password" in greeting
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_packet_triggers_close(mysql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mysql_mod)
|
||||||
|
proto.data_received(_login_packet())
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_packet_returns_access_denied(mysql_mod):
|
||||||
|
proto, _, written = _make_protocol(mysql_mod)
|
||||||
|
proto.data_received(_login_packet())
|
||||||
|
resp = b"".join(written)
|
||||||
|
assert b"\xff" in resp # error packet marker
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_logs_username():
|
||||||
|
mod = _load_mysql()
|
||||||
|
log_mock = sys.modules["decnet_logging"]
|
||||||
|
proto, _, _ = _make_protocol(mod)
|
||||||
|
proto.data_received(_login_packet(username="hacker"))
|
||||||
|
calls_str = str(log_mock.syslog_line.call_args_list)
|
||||||
|
assert "hacker" in calls_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_payload_packet_does_not_crash(mysql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mysql_mod)
|
||||||
|
proto.data_received(_make_packet(b"", seq=1))
|
||||||
|
# Empty payload is silently skipped — no crash, no close
|
||||||
|
transport.close.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_header_waits_for_more(mysql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mysql_mod)
|
||||||
|
proto.data_received(b"\x00\x00\x00") # only 3 bytes — not enough
|
||||||
|
transport.close.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_lost_does_not_raise(mysql_mod):
|
||||||
|
proto, _, _ = _make_protocol(mysql_mod)
|
||||||
|
proto.connection_lost(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Regression: oversized length field ───────────────────────────────────────
|
||||||
|
|
||||||
|
def test_length_over_1mb_closes(mysql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mysql_mod)
|
||||||
|
# 1MB + 1 in 3-byte LE: 0x100001 → b'\x01\x00\x10'
|
||||||
|
over_1mb = struct.pack("<I", 1024 * 1024 + 1)[:3]
|
||||||
|
data = over_1mb + b"\x01" # seq=1
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_3byte_length_closes(mysql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mysql_mod)
|
||||||
|
# 0xFFFFFF = 16,777,215 — max representable in 3 bytes, clearly > 1MB cap
|
||||||
|
data = b"\xff\xff\xff\x01"
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_length_just_over_1mb_closes(mysql_mod):
|
||||||
|
proto, transport, _ = _make_protocol(mysql_mod)
|
||||||
|
# 1MB + 1 byte — just over the cap
|
||||||
|
just_over = struct.pack("<I", 1024 * 1024 + 1)[:3]
|
||||||
|
data = just_over + b"\x01"
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fuzz ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.fuzz
|
||||||
|
@given(data=st.binary(min_size=0, max_size=512))
|
||||||
|
@settings(**_FUZZ_SETTINGS)
|
||||||
|
def test_fuzz_arbitrary_bytes(data):
|
||||||
|
mod = _load_mysql()
|
||||||
|
proto, _, _ = _make_protocol(mod)
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
189
tests/service_testing/test_postgres.py
Normal file
189
tests/service_testing/test_postgres.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
Tests for templates/postgres/server.py
|
||||||
|
|
||||||
|
Covers the PostgreSQL startup / MD5-auth handshake happy path and regression
|
||||||
|
tests for zero/tiny/huge msg_len in both the startup and auth states.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from hypothesis import given, settings
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from .conftest import _FUZZ_SETTINGS, make_fake_decnet_logging, run_with_timeout
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _load_postgres():
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key in ("postgres_server", "decnet_logging"):
|
||||||
|
del sys.modules[key]
|
||||||
|
sys.modules["decnet_logging"] = make_fake_decnet_logging()
|
||||||
|
spec = importlib.util.spec_from_file_location("postgres_server", "templates/postgres/server.py")
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
def _make_protocol(mod):
|
||||||
|
proto = mod.PostgresProtocol()
|
||||||
|
transport = MagicMock()
|
||||||
|
written: list[bytes] = []
|
||||||
|
transport.write.side_effect = written.append
|
||||||
|
proto.connection_made(transport)
|
||||||
|
return proto, transport, written
|
||||||
|
|
||||||
|
|
||||||
|
def _startup_msg(user: str = "postgres", database: str = "postgres") -> bytes:
|
||||||
|
"""Build a valid PostgreSQL startup message."""
|
||||||
|
params = f"user\x00{user}\x00database\x00{database}\x00\x00".encode()
|
||||||
|
protocol = struct.pack(">I", 0x00030000)
|
||||||
|
body = protocol + params
|
||||||
|
msg_len = struct.pack(">I", 4 + len(body))
|
||||||
|
return msg_len + body
|
||||||
|
|
||||||
|
|
||||||
|
def _ssl_request() -> bytes:
|
||||||
|
return struct.pack(">II", 8, 80877103)
|
||||||
|
|
||||||
|
|
||||||
|
def _password_msg(password: str = "wrongpass") -> bytes:
|
||||||
|
pw = password.encode() + b"\x00"
|
||||||
|
return b"p" + struct.pack(">I", 4 + len(pw)) + pw
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def postgres_mod():
|
||||||
|
return _load_postgres()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Happy path ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_ssl_request_returns_N(postgres_mod):
|
||||||
|
proto, _, written = _make_protocol(postgres_mod)
|
||||||
|
proto.data_received(_ssl_request())
|
||||||
|
assert b"N" in b"".join(written)
|
||||||
|
|
||||||
|
|
||||||
|
def test_startup_sends_auth_challenge(postgres_mod):
|
||||||
|
proto, _, written = _make_protocol(postgres_mod)
|
||||||
|
proto.data_received(_startup_msg())
|
||||||
|
resp = b"".join(written)
|
||||||
|
# AuthenticationMD5Password = 'R' + len(12) + type(5) + salt(4)
|
||||||
|
assert resp[0:1] == b"R"
|
||||||
|
|
||||||
|
|
||||||
|
def test_startup_logs_username():
|
||||||
|
mod = _load_postgres()
|
||||||
|
log_mock = sys.modules["decnet_logging"]
|
||||||
|
proto, _, _ = _make_protocol(mod)
|
||||||
|
proto.data_received(_startup_msg(user="attacker"))
|
||||||
|
log_mock.syslog_line.assert_called()
|
||||||
|
calls_str = str(log_mock.syslog_line.call_args_list)
|
||||||
|
assert "attacker" in calls_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_becomes_auth_after_startup(postgres_mod):
|
||||||
|
proto, _, _ = _make_protocol(postgres_mod)
|
||||||
|
proto.data_received(_startup_msg())
|
||||||
|
assert proto._state == "auth"
|
||||||
|
|
||||||
|
|
||||||
|
def test_password_triggers_close(postgres_mod):
|
||||||
|
proto, transport, _ = _make_protocol(postgres_mod)
|
||||||
|
proto.data_received(_startup_msg())
|
||||||
|
transport.reset_mock()
|
||||||
|
proto.data_received(_password_msg())
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_startup_waits(postgres_mod):
|
||||||
|
proto, transport, _ = _make_protocol(postgres_mod)
|
||||||
|
proto.data_received(b"\x00\x00\x00") # only 3 bytes — not enough for msg_len
|
||||||
|
transport.close.assert_not_called()
|
||||||
|
assert proto._state == "startup"
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_lost_does_not_raise(postgres_mod):
|
||||||
|
proto, _, _ = _make_protocol(postgres_mod)
|
||||||
|
proto.connection_lost(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Regression: startup state bad msg_len ────────────────────────────────────
|
||||||
|
|
||||||
|
def test_zero_msg_len_startup_closes(postgres_mod):
|
||||||
|
proto, transport, _ = _make_protocol(postgres_mod)
|
||||||
|
run_with_timeout(proto.data_received, b"\x00\x00\x00\x00")
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_msg_len_4_startup_closes(postgres_mod):
|
||||||
|
proto, transport, _ = _make_protocol(postgres_mod)
|
||||||
|
# msg_len=4 means zero-byte body — too small for startup (needs protocol version)
|
||||||
|
run_with_timeout(proto.data_received, struct.pack(">I", 4) + b"\x00" * 4)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_msg_len_7_startup_closes(postgres_mod):
|
||||||
|
proto, transport, _ = _make_protocol(postgres_mod)
|
||||||
|
run_with_timeout(proto.data_received, struct.pack(">I", 7) + b"\x00" * 7)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_huge_msg_len_startup_closes(postgres_mod):
|
||||||
|
proto, transport, _ = _make_protocol(postgres_mod)
|
||||||
|
run_with_timeout(proto.data_received, struct.pack(">I", 0x7FFFFFFF) + b"\x00" * 4)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Regression: auth state bad msg_len ───────────────────────────────────────
|
||||||
|
|
||||||
|
def test_zero_msg_len_auth_closes(postgres_mod):
|
||||||
|
proto, transport, _ = _make_protocol(postgres_mod)
|
||||||
|
proto.data_received(_startup_msg())
|
||||||
|
transport.reset_mock()
|
||||||
|
# 'p' + msg_len=0
|
||||||
|
run_with_timeout(proto.data_received, b"p" + struct.pack(">I", 0))
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_msg_len_1_auth_closes(postgres_mod):
|
||||||
|
proto, transport, _ = _make_protocol(postgres_mod)
|
||||||
|
proto.data_received(_startup_msg())
|
||||||
|
transport.reset_mock()
|
||||||
|
run_with_timeout(proto.data_received, b"p" + struct.pack(">I", 1) + b"\x00" * 5)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_huge_msg_len_auth_closes(postgres_mod):
|
||||||
|
proto, transport, _ = _make_protocol(postgres_mod)
|
||||||
|
proto.data_received(_startup_msg())
|
||||||
|
transport.reset_mock()
|
||||||
|
run_with_timeout(proto.data_received, b"p" + struct.pack(">I", 0x7FFFFFFF) + b"\x00" * 5)
|
||||||
|
transport.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fuzz ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.fuzz
|
||||||
|
@given(data=st.binary(min_size=0, max_size=512))
|
||||||
|
@settings(**_FUZZ_SETTINGS)
|
||||||
|
def test_fuzz_startup_state(data):
|
||||||
|
mod = _load_postgres()
|
||||||
|
proto, _, _ = _make_protocol(mod)
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.fuzz
|
||||||
|
@given(data=st.binary(min_size=0, max_size=512))
|
||||||
|
@settings(**_FUZZ_SETTINGS)
|
||||||
|
def test_fuzz_auth_state(data):
|
||||||
|
mod = _load_postgres()
|
||||||
|
proto, _, _ = _make_protocol(mod)
|
||||||
|
proto.data_received(_startup_msg())
|
||||||
|
run_with_timeout(proto.data_received, data)
|
||||||
Reference in New Issue
Block a user