fix(types): P3 — annotate transport in all template protocol servers; 0 errors in templates/
- asyncio.Protocol (TCP): _transport: asyncio.Transport | None = None + cast() in connection_made; assert guards in every method that directly accesses the field. Files: pop3, smtp, mqtt, postgres, mssql, mongodb, imap, ldap, redis, mysql, sip, vnc. - asyncio.DatagramProtocol (UDP): _transport: asyncio.DatagramTransport | None = None. Files: snmp, tftp, SIPUDPProtocol. - RDP: assert new_transport is not None after start_tls() to narrow Transport | None. - FTP (Twisted): assert self.transport is not None + targeted type: ignore for imprecise Twisted stubs (misc/override/arg-type/attr-defined), IReactorTCP cast for listenTCP. - conpot: proc.stdout is None guard before iteration. - Bonus fixes surfaced by annotation: - smtp: get_payload(decode=True) bytes narrowing (arg-type on sha256) - postgres: rename shadowed `msg` param to `err_msg` in _handle_startup - mongodb: base64.binascii.Error → import binascii; binascii.Error - imap: result: list[int] = [] (var-annotated)
This commit is contained in:
@@ -128,6 +128,9 @@ def main():
|
||||
signal.signal(signal.SIGINT, _forward)
|
||||
|
||||
try:
|
||||
if proc.stdout is None:
|
||||
proc.wait()
|
||||
return
|
||||
for raw_line in proc.stdout:
|
||||
line = raw_line.rstrip()
|
||||
if not line:
|
||||
|
||||
@@ -7,9 +7,12 @@ forwards events as JSON to LOG_TARGET if set.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.interfaces import IReactorTCP
|
||||
from twisted.protocols.ftp import FTP, FTPFactory, FTPAnonymousShell
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python import log as twisted_log
|
||||
|
||||
@@ -95,7 +98,8 @@ _BAIT_PATH = _setup_bait_fs()
|
||||
|
||||
class ServerFTP(FTP):
|
||||
def connectionMade(self):
|
||||
peer = self.transport.getPeer()
|
||||
assert self.transport is not None
|
||||
peer = self.transport.getPeer() # type: ignore[misc]
|
||||
_log("connection", src_ip=peer.host, src_port=peer.port)
|
||||
super().connectionMade()
|
||||
|
||||
@@ -120,15 +124,16 @@ class ServerFTP(FTP):
|
||||
return defer.succeed((530, "Login incorrect."))
|
||||
self.state = self.AUTHED
|
||||
self._user = getattr(self, "_server_user", "anonymous")
|
||||
self.shell = FTPAnonymousShell(FilePath(_BAIT_PATH))
|
||||
self.shell = FTPAnonymousShell(FilePath(_BAIT_PATH)) # type: ignore[assignment]
|
||||
return defer.succeed((230, "Login successful."))
|
||||
|
||||
def ftp_RETR(self, path):
|
||||
_log("download_attempt", path=path)
|
||||
return super().ftp_RETR(path)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
peer = self.transport.getPeer()
|
||||
def connectionLost(self, reason: Failure) -> None: # type: ignore[override]
|
||||
assert self.transport is not None
|
||||
peer = self.transport.getPeer() # type: ignore[misc]
|
||||
_log("disconnect", src_ip=peer.host, src_port=peer.port)
|
||||
super().connectionLost(reason)
|
||||
|
||||
@@ -140,5 +145,5 @@ class ServerFTPFactory(FTPFactory):
|
||||
if __name__ == "__main__":
|
||||
twisted_log.startLoggingWithObserver(lambda e: None, setStdout=False)
|
||||
_log("startup", msg=f"FTP server starting as {NODE_NAME} on port {PORT}")
|
||||
reactor.listenTCP(PORT, ServerFTPFactory())
|
||||
reactor.run()
|
||||
cast(IReactorTCP, reactor).listenTCP(PORT, ServerFTPFactory()) # type: ignore[arg-type]
|
||||
reactor.run() # type: ignore[attr-defined]
|
||||
|
||||
@@ -17,6 +17,7 @@ import os
|
||||
import time
|
||||
from email.utils import getaddresses
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from syslog_bridge import (
|
||||
SEVERITY_WARNING,
|
||||
encode_secret,
|
||||
@@ -377,7 +378,7 @@ def _log(event_type: str, severity: int = 6, **kwargs) -> None:
|
||||
|
||||
def _parse_seq_range(range_str: str, total: int) -> list[int]:
|
||||
"""Parse IMAP sequence set ('1', '1:3', '1:*', '*') → list of 1-based indices."""
|
||||
result = []
|
||||
result: list[int] = []
|
||||
for part in range_str.split(","):
|
||||
part = part.strip()
|
||||
if ":" in part:
|
||||
@@ -472,6 +473,9 @@ def _build_fetch_response(seq: int, msg: dict, items: list[str]) -> bytes:
|
||||
# ── Protocol ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class IMAPProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int]
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = ("?", 0)
|
||||
@@ -479,12 +483,12 @@ class IMAPProtocol(asyncio.Protocol):
|
||||
self._state = "NOT_AUTHENTICATED"
|
||||
self._selected = None # mailbox name currently selected
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = self._transport.get_extra_info("peername", ("?", 0))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
banner = IMAP_BANNER if IMAP_BANNER.endswith("\r\n") else IMAP_BANNER + "\r\n"
|
||||
transport.write(banner.encode())
|
||||
self._transport.write(banner.encode())
|
||||
|
||||
def data_received(self, data):
|
||||
self._buf += data
|
||||
@@ -519,6 +523,7 @@ class IMAPProtocol(asyncio.Protocol):
|
||||
elif cmd == "LOGOUT":
|
||||
self._w(b"* BYE Logging out\r\n")
|
||||
self._w(f"{tag} OK LOGOUT completed\r\n")
|
||||
assert self._transport is not None
|
||||
self._transport.close()
|
||||
|
||||
# NOT_AUTHENTICATED only
|
||||
@@ -638,6 +643,7 @@ class IMAPProtocol(asyncio.Protocol):
|
||||
if use_uid and "UID" not in items:
|
||||
items = ["UID"] + items
|
||||
|
||||
assert self._transport is not None
|
||||
for seq in indices:
|
||||
if 1 <= seq <= total:
|
||||
self._transport.write(_build_fetch_response(seq, emails[seq - 1], items))
|
||||
@@ -662,6 +668,7 @@ class IMAPProtocol(asyncio.Protocol):
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def _w(self, data: str | bytes) -> None:
|
||||
assert self._transport is not None
|
||||
if isinstance(data, str):
|
||||
data = data.encode()
|
||||
self._transport.write(data)
|
||||
|
||||
@@ -8,6 +8,7 @@ invalidCredentials error. Logs all interactions as JSON.
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
import instance_seed as _seed
|
||||
from syslog_bridge import (
|
||||
@@ -137,14 +138,17 @@ def _bind_error_response(message_id: int, result_code: int = 49, error_text: str
|
||||
|
||||
|
||||
class LDAPProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int] | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = None
|
||||
self._buf = b""
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = cast(tuple[str, int], self._transport.get_extra_info("peername", ("?", 0)))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
|
||||
def data_received(self, data):
|
||||
@@ -171,7 +175,9 @@ class LDAPProtocol(asyncio.Protocol):
|
||||
self._buf = self._buf[msg_len:]
|
||||
self._handle_message(msg)
|
||||
|
||||
def _handle_message(self, msg: bytes):
|
||||
def _handle_message(self, msg: bytes) -> None:
|
||||
assert self._transport is not None
|
||||
assert self._peer is not None
|
||||
# Extract messageID for the response
|
||||
try:
|
||||
message_id = msg[4] if len(msg) > 4 else 1
|
||||
|
||||
@@ -8,8 +8,10 @@ received messages as JSON.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import os
|
||||
import struct
|
||||
from typing import cast
|
||||
|
||||
import instance_seed as _seed
|
||||
from syslog_bridge import syslog_line, write_syslog_file, forward_syslog
|
||||
@@ -197,6 +199,9 @@ def _log(event_type: str, severity: int = 6, **kwargs) -> None:
|
||||
|
||||
|
||||
class MongoDBProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int] | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = None
|
||||
@@ -207,12 +212,13 @@ class MongoDBProtocol(asyncio.Protocol):
|
||||
self._sasl_username: str | None = None
|
||||
self._sasl_mechanism: str | None = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = cast(tuple[str, int], self._transport.get_extra_info("peername", ("?", 0)))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
|
||||
def data_received(self, data):
|
||||
def data_received(self, data: bytes) -> None:
|
||||
assert self._transport is not None
|
||||
self._buf += data
|
||||
while len(self._buf) >= 16:
|
||||
msg_len = struct.unpack("<I", self._buf[:4])[0]
|
||||
@@ -226,7 +232,9 @@ class MongoDBProtocol(asyncio.Protocol):
|
||||
self._buf = self._buf[msg_len:]
|
||||
self._handle_message(msg)
|
||||
|
||||
def _handle_message(self, msg: bytes):
|
||||
def _handle_message(self, msg: bytes) -> None:
|
||||
assert self._transport is not None
|
||||
assert self._peer is not None
|
||||
if len(msg) < 16:
|
||||
return
|
||||
request_id = struct.unpack("<I", msg[4:8])[0]
|
||||
@@ -285,6 +293,7 @@ class MongoDBProtocol(asyncio.Protocol):
|
||||
self._transport.write(_op_reply(request_id, reply_doc))
|
||||
|
||||
def _handle_command(self, cmd: dict) -> None:
|
||||
assert self._peer is not None
|
||||
"""Parse a single MongoDB command document for SCRAM auth.
|
||||
|
||||
saslStart — client-first-message in payload. Extract
|
||||
@@ -318,7 +327,7 @@ class MongoDBProtocol(asyncio.Protocol):
|
||||
return
|
||||
try:
|
||||
proof_raw = base64.b64decode(proof_b64, validate=True)
|
||||
except (ValueError, base64.binascii.Error):
|
||||
except (ValueError, binascii.Error):
|
||||
return
|
||||
mech = (self._sasl_mechanism or "").upper()
|
||||
if "SHA-256" in mech or "SHA256" in mech:
|
||||
|
||||
@@ -12,6 +12,7 @@ import json
|
||||
import os
|
||||
import random
|
||||
import struct
|
||||
from typing import cast
|
||||
|
||||
import instance_seed as _seed
|
||||
from syslog_bridge import (
|
||||
@@ -209,6 +210,9 @@ def _generate_topics() -> dict:
|
||||
|
||||
|
||||
class MQTTProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int] | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = None
|
||||
@@ -216,9 +220,9 @@ class MQTTProtocol(asyncio.Protocol):
|
||||
self._auth = False
|
||||
self._topics = _generate_topics()
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = cast(tuple[str, int], self._transport.get_extra_info("peername", ("?", 0)))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
|
||||
def data_received(self, data):
|
||||
@@ -231,6 +235,8 @@ class MQTTProtocol(asyncio.Protocol):
|
||||
self._transport.close()
|
||||
|
||||
def _process(self):
|
||||
assert self._transport is not None
|
||||
assert self._peer is not None
|
||||
while len(self._buf) >= 2:
|
||||
pkt_byte = self._buf[0]
|
||||
pkt_type = (pkt_byte >> 4) & 0x0f
|
||||
|
||||
@@ -9,6 +9,7 @@ import asyncio
|
||||
import base64
|
||||
import os
|
||||
import struct
|
||||
from typing import cast
|
||||
|
||||
import instance_seed as _seed
|
||||
from syslog_bridge import syslog_line, write_syslog_file, forward_syslog
|
||||
@@ -108,18 +109,23 @@ def _tds_error_packet(message: str) -> bytes:
|
||||
|
||||
|
||||
class MSSQLProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int] | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = None
|
||||
self._buf = b""
|
||||
self._prelogin_done = False
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = cast(tuple[str, int], self._transport.get_extra_info("peername", ("?", 0)))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
|
||||
def data_received(self, data):
|
||||
def data_received(self, data: bytes) -> None:
|
||||
assert self._transport is not None
|
||||
assert self._peer is not None
|
||||
self._buf += data
|
||||
while len(self._buf) >= 8:
|
||||
pkt_type = self._buf[0]
|
||||
@@ -138,7 +144,9 @@ class MSSQLProtocol(asyncio.Protocol):
|
||||
self._buf = b""
|
||||
break
|
||||
|
||||
def _handle_packet(self, pkt_type: int, payload: bytes):
|
||||
def _handle_packet(self, pkt_type: int, payload: bytes) -> None:
|
||||
assert self._transport is not None
|
||||
assert self._peer is not None
|
||||
if pkt_type == 0x12: # Pre-login
|
||||
self._transport.write(_PRELOGIN_RESP)
|
||||
self._prelogin_done = True
|
||||
|
||||
@@ -11,6 +11,7 @@ import base64
|
||||
import itertools
|
||||
import os
|
||||
import struct
|
||||
from typing import cast
|
||||
|
||||
import instance_seed as _seed
|
||||
from syslog_bridge import syslog_line, write_syslog_file, forward_syslog
|
||||
@@ -74,6 +75,9 @@ def _log(event_type: str, severity: int = 6, **kwargs) -> None:
|
||||
|
||||
|
||||
class MySQLProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int] | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = None
|
||||
@@ -84,15 +88,16 @@ class MySQLProtocol(asyncio.Protocol):
|
||||
# same decky never present identical auth challenges.
|
||||
self._salt = _seed.fresh_bytes(20)
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = cast(tuple[str, int], self._transport.get_extra_info("peername", ("?", 0)))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1],
|
||||
connection_id=self._conn_id)
|
||||
transport.write(_make_packet(_build_greeting(self._conn_id, self._salt), seq=0))
|
||||
self._transport.write(_make_packet(_build_greeting(self._conn_id, self._salt), seq=0))
|
||||
self._greeted = True
|
||||
|
||||
def data_received(self, data):
|
||||
def data_received(self, data: bytes) -> None:
|
||||
assert self._transport is not None
|
||||
self._buf += data
|
||||
# MySQL packets: 3-byte length + 1-byte seq + payload
|
||||
while len(self._buf) >= 4:
|
||||
@@ -107,7 +112,8 @@ class MySQLProtocol(asyncio.Protocol):
|
||||
self._buf = self._buf[4 + length:]
|
||||
self._handle_packet(payload)
|
||||
|
||||
def _handle_packet(self, payload: bytes):
|
||||
def _handle_packet(self, payload: bytes) -> None:
|
||||
assert self._peer is not None
|
||||
if not payload:
|
||||
return
|
||||
# Login packet: capability flags (4), max_packet (4), charset (1),
|
||||
|
||||
@@ -13,6 +13,7 @@ import asyncio
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from syslog_bridge import (
|
||||
SEVERITY_WARNING,
|
||||
encode_secret,
|
||||
@@ -238,6 +239,9 @@ def _log(event_type: str, severity: int = 6, **kwargs) -> None:
|
||||
# ── Protocol ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class POP3Protocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int]
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = ("?", 0)
|
||||
@@ -246,14 +250,14 @@ class POP3Protocol(asyncio.Protocol):
|
||||
self._current_user: str | None = None
|
||||
self._deleted: set[int] = set() # 0-based indices of DELE'd messages
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = self._transport.get_extra_info("peername", ("?", 0))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
banner = POP3_BANNER if POP3_BANNER.endswith("\r\n") else POP3_BANNER + "\r\n"
|
||||
if not banner.startswith("+OK"):
|
||||
banner = "+OK " + banner
|
||||
transport.write(banner.encode())
|
||||
self._transport.write(banner.encode())
|
||||
|
||||
def data_received(self, data):
|
||||
self._buf += data
|
||||
@@ -267,6 +271,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
# ── Command dispatch ──────────────────────────────────────────────────────
|
||||
|
||||
def _handle_line(self, line: str) -> None:
|
||||
assert self._transport is not None
|
||||
parts = line.split(None, 1)
|
||||
if not parts:
|
||||
return
|
||||
@@ -314,6 +319,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
# ── Command implementations ───────────────────────────────────────────────
|
||||
|
||||
def _cmd_user(self, args: str) -> None:
|
||||
assert self._transport is not None
|
||||
if self._state != "AUTHORIZATION":
|
||||
self._transport.write(b"-ERR Already authenticated\r\n")
|
||||
return
|
||||
@@ -321,6 +327,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
self._transport.write(b"+OK User name accepted, password please\r\n")
|
||||
|
||||
def _cmd_pass(self, args: str) -> None:
|
||||
assert self._transport is not None
|
||||
if self._state != "AUTHORIZATION":
|
||||
self._transport.write(b"-ERR Already authenticated\r\n")
|
||||
return
|
||||
@@ -342,6 +349,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
self._transport.write(b"-ERR Authentication failed.\r\n")
|
||||
|
||||
def _require_transaction(self) -> bool:
|
||||
assert self._transport is not None
|
||||
if self._state != "TRANSACTION":
|
||||
self._transport.write(b"-ERR Not authenticated\r\n")
|
||||
return False
|
||||
@@ -356,6 +364,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
]
|
||||
|
||||
def _cmd_stat(self) -> None:
|
||||
assert self._transport is not None
|
||||
if not self._require_transaction():
|
||||
return
|
||||
msgs = self._active_messages()
|
||||
@@ -363,6 +372,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
self._transport.write(f"+OK {len(msgs)} {total}\r\n".encode())
|
||||
|
||||
def _cmd_list(self, args: str) -> None:
|
||||
assert self._transport is not None
|
||||
if not self._require_transaction():
|
||||
return
|
||||
emails = _get_emails()
|
||||
@@ -386,6 +396,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
self._transport.write(b".\r\n")
|
||||
|
||||
def _cmd_retr(self, args: str) -> None:
|
||||
assert self._transport is not None
|
||||
if not self._require_transaction():
|
||||
return
|
||||
try:
|
||||
@@ -407,6 +418,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
self._transport.write(b"-ERR Invalid argument\r\n")
|
||||
|
||||
def _cmd_top(self, args: str) -> None:
|
||||
assert self._transport is not None
|
||||
if not self._require_transaction():
|
||||
return
|
||||
try:
|
||||
@@ -436,6 +448,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
self._transport.write(b"-ERR Invalid arguments\r\n")
|
||||
|
||||
def _cmd_uidl(self, args: str) -> None:
|
||||
assert self._transport is not None
|
||||
if not self._require_transaction():
|
||||
return
|
||||
if args:
|
||||
@@ -455,6 +468,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
self._transport.write(b".\r\n")
|
||||
|
||||
def _cmd_dele(self, args: str) -> None:
|
||||
assert self._transport is not None
|
||||
if not self._require_transaction():
|
||||
return
|
||||
try:
|
||||
@@ -470,6 +484,7 @@ class POP3Protocol(asyncio.Protocol):
|
||||
self._transport.write(b"-ERR Invalid argument\r\n")
|
||||
|
||||
def _cmd_rset(self) -> None:
|
||||
assert self._transport is not None
|
||||
if not self._require_transaction():
|
||||
return
|
||||
self._deleted.clear()
|
||||
|
||||
@@ -9,6 +9,7 @@ returns an error. Logs all interactions as JSON.
|
||||
import asyncio
|
||||
import os
|
||||
import struct
|
||||
from typing import cast
|
||||
|
||||
import instance_seed as _seed
|
||||
import base64 as _base64
|
||||
@@ -59,15 +60,18 @@ def _log(event_type: str, severity: int = 6, **kwargs) -> None:
|
||||
|
||||
|
||||
class PostgresProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int] | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = None
|
||||
self._buf = b""
|
||||
self._state = "startup"
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = cast(tuple[str, int], self._transport.get_extra_info("peername", ("?", 0)))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
|
||||
def data_received(self, data):
|
||||
@@ -75,6 +79,7 @@ class PostgresProtocol(asyncio.Protocol):
|
||||
self._process()
|
||||
|
||||
def _process(self):
|
||||
assert self._transport is not None
|
||||
if self._state == "startup":
|
||||
if len(self._buf) < 4:
|
||||
return
|
||||
@@ -104,7 +109,9 @@ class PostgresProtocol(asyncio.Protocol):
|
||||
if msg_type == "p":
|
||||
self._handle_password(payload)
|
||||
|
||||
def _handle_startup(self, msg: bytes):
|
||||
def _handle_startup(self, msg: bytes) -> None:
|
||||
assert self._transport is not None
|
||||
assert self._peer is not None
|
||||
# Startup message: length(4) + protocol_version(4) + params (key=value\0 pairs)
|
||||
if len(msg) < 8:
|
||||
return
|
||||
@@ -128,8 +135,8 @@ class PostgresProtocol(asyncio.Protocol):
|
||||
# rejects *before* asking for a password. Short-circuit so the decoy
|
||||
# matches that behavior and exposes the per-decky DB list.
|
||||
if database and database not in _DATABASES:
|
||||
msg = f'database "{database}" does not exist'
|
||||
self._transport.write(_error_response("FATAL", "3D000", msg))
|
||||
err_msg = f'database "{database}" does not exist'
|
||||
self._transport.write(_error_response("FATAL", "3D000", err_msg))
|
||||
self._transport.close()
|
||||
return
|
||||
self._state = "auth"
|
||||
@@ -137,7 +144,9 @@ class PostgresProtocol(asyncio.Protocol):
|
||||
auth_md5 = b"R" + struct.pack(">I", 12) + struct.pack(">I", 5) + salt
|
||||
self._transport.write(auth_md5)
|
||||
|
||||
def _handle_password(self, payload: bytes):
|
||||
def _handle_password(self, payload: bytes) -> None:
|
||||
assert self._transport is not None
|
||||
assert self._peer is not None
|
||||
# Postgres MD5 challenge-response: the wire form is the literal
|
||||
# ASCII string "md5" + 32 hex chars (md5(md5(pw+user)+salt)).
|
||||
# Plaintext is unrecoverable, so we land this in the Credential
|
||||
|
||||
@@ -331,6 +331,7 @@ async def _upgrade_to_tls_and_capture(
|
||||
# into a StreamReader/StreamWriter pair the rest of the handler can use.
|
||||
new_reader = asyncio.StreamReader(loop=loop)
|
||||
new_protocol = asyncio.StreamReaderProtocol(new_reader, loop=loop)
|
||||
assert new_transport is not None
|
||||
new_transport.set_protocol(new_protocol)
|
||||
new_protocol.connection_made(new_transport)
|
||||
new_writer = asyncio.StreamWriter(new_transport, new_protocol, new_reader, loop)
|
||||
|
||||
@@ -7,6 +7,7 @@ KEYS, and arbitrary commands. Logs every command and argument as JSON.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
import instance_seed as _seed
|
||||
from syslog_bridge import (
|
||||
@@ -203,15 +204,18 @@ def _config_get(pattern: str) -> bytes:
|
||||
|
||||
|
||||
class RedisProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int] | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = None
|
||||
self._parser = RESPParser()
|
||||
self._authed = not _REQUIREPASS # auth satisfied iff no password set
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = cast(tuple[str, int], self._transport.get_extra_info("peername", ("?", 0)))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
|
||||
def data_received(self, data):
|
||||
@@ -228,7 +232,8 @@ class RedisProtocol(asyncio.Protocol):
|
||||
if self._transport and not self._transport.is_closing():
|
||||
self._transport.write(payload)
|
||||
|
||||
def _handle_command(self, parts):
|
||||
def _handle_command(self, parts) -> None:
|
||||
assert self._peer is not None
|
||||
if not parts:
|
||||
return
|
||||
verb = parts[0].upper()
|
||||
|
||||
@@ -8,6 +8,7 @@ Authorization header and call metadata, then responds with 401 Unauthorized.
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import cast
|
||||
from syslog_bridge import (
|
||||
classify_authorization,
|
||||
forward_syslog,
|
||||
@@ -98,11 +99,13 @@ def _handle_message(data: bytes, src_addr) -> bytes | None:
|
||||
|
||||
|
||||
class SIPUDPProtocol(asyncio.DatagramProtocol):
|
||||
_transport: asyncio.DatagramTransport | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
response = _handle_message(data, addr)
|
||||
@@ -111,21 +114,24 @@ class SIPUDPProtocol(asyncio.DatagramProtocol):
|
||||
|
||||
|
||||
class SIPTCPProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int] | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = None
|
||||
self._buf = b""
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = cast(tuple[str, int], self._transport.get_extra_info("peername", ("?", 0)))
|
||||
|
||||
def data_received(self, data):
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._buf += data
|
||||
if b"\r\n\r\n" in self._buf or b"\n\n" in self._buf:
|
||||
response = _handle_message(self._buf, self._peer)
|
||||
self._buf = b""
|
||||
if response:
|
||||
if response and self._transport:
|
||||
self._transport.write(response)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
|
||||
@@ -30,6 +30,7 @@ from datetime import datetime, timezone
|
||||
from email import message_from_bytes
|
||||
from email.header import decode_header, make_header
|
||||
from email.message import Message
|
||||
from typing import cast
|
||||
|
||||
import instance_seed as _seed
|
||||
from syslog_bridge import (
|
||||
@@ -150,7 +151,8 @@ def _summarize_message(body: bytes, msg_id: str) -> dict:
|
||||
if not filename and "attachment" not in disposition:
|
||||
continue
|
||||
try:
|
||||
payload = part.get_payload(decode=True) or b""
|
||||
_raw = part.get_payload(decode=True) or b""
|
||||
payload: bytes = _raw if isinstance(_raw, bytes) else b""
|
||||
except Exception:
|
||||
payload = b""
|
||||
attachments.append({
|
||||
@@ -207,6 +209,9 @@ def _decode_auth_plain(blob: str) -> tuple[str, str]:
|
||||
|
||||
|
||||
class SMTPProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int]
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = ("?", 0)
|
||||
@@ -228,11 +233,11 @@ class SMTPProtocol(asyncio.Protocol):
|
||||
|
||||
# ── asyncio.Protocol ──────────────────────────────────────────────────────
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = self._transport.get_extra_info("peername", ("?", 0))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
transport.write(f"{_SMTP_BANNER}\r\n".encode())
|
||||
self._transport.write(f"{_SMTP_BANNER}\r\n".encode())
|
||||
|
||||
def data_received(self, data):
|
||||
self._buf += data
|
||||
@@ -247,6 +252,7 @@ class SMTPProtocol(asyncio.Protocol):
|
||||
# ── Command dispatch ──────────────────────────────────────────────────────
|
||||
|
||||
def _handle_line(self, line: str) -> None:
|
||||
assert self._transport is not None
|
||||
# ── DATA body accumulation ────────────────────────────────────────────
|
||||
if self._in_data:
|
||||
if line == ".":
|
||||
@@ -444,6 +450,7 @@ class SMTPProtocol(asyncio.Protocol):
|
||||
# ── AUTH helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _handle_auth(self, args: str) -> None:
|
||||
assert self._transport is not None
|
||||
parts = args.split(None, 1)
|
||||
mech = parts[0].upper() if parts else ""
|
||||
initial = parts[1] if len(parts) > 1 else ""
|
||||
@@ -468,6 +475,7 @@ class SMTPProtocol(asyncio.Protocol):
|
||||
self._transport.write(b"504 5.5.4 Unrecognized authentication mechanism\r\n")
|
||||
|
||||
def _finish_auth(self, username: str, password: str) -> None:
|
||||
assert self._transport is not None
|
||||
_log("auth_attempt", src=self._peer[0],
|
||||
username=username, principal=username,
|
||||
severity=SEVERITY_WARNING, **encode_secret(password))
|
||||
|
||||
@@ -9,6 +9,7 @@ Logs all requests as JSON.
|
||||
import asyncio
|
||||
import os
|
||||
import struct
|
||||
from typing import cast
|
||||
from syslog_bridge import (
|
||||
encode_secret,
|
||||
forward_syslog,
|
||||
@@ -225,11 +226,13 @@ def _build_response(version: int, community: str, request_id: int, oids: list) -
|
||||
|
||||
|
||||
class SNMPProtocol(asyncio.DatagramProtocol):
|
||||
_transport: asyncio.DatagramTransport | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
try:
|
||||
@@ -244,7 +247,8 @@ class SNMPProtocol(asyncio.DatagramProtocol):
|
||||
principal=None, secret_kind="snmp_community",
|
||||
**encode_secret(community))
|
||||
response = _build_response(version, community, request_id, oids)
|
||||
self._transport.sendto(response, addr)
|
||||
if self._transport is not None:
|
||||
self._transport.sendto(response, addr)
|
||||
except Exception as e:
|
||||
_log("parse_error", severity=4, src=addr[0], error=str(e), data=data[:64].hex())
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ then responds with an error packet. Logs all requests as JSON.
|
||||
import asyncio
|
||||
import os
|
||||
import struct
|
||||
from typing import cast
|
||||
from syslog_bridge import syslog_line, write_syslog_file, forward_syslog
|
||||
|
||||
NODE_NAME = os.environ.get("NODE_NAME", "tftpserver")
|
||||
@@ -33,11 +34,13 @@ def _log(event_type: str, severity: int = 6, **kwargs) -> None:
|
||||
|
||||
|
||||
class TFTPProtocol(asyncio.DatagramProtocol):
|
||||
_transport: asyncio.DatagramTransport | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: bytes, addr):
|
||||
if len(data) < 4:
|
||||
@@ -56,7 +59,8 @@ class TFTPProtocol(asyncio.DatagramProtocol):
|
||||
filename=filename,
|
||||
mode=mode,
|
||||
)
|
||||
self._transport.sendto(_error_pkt(2, "Access violation"), addr)
|
||||
if self._transport is not None:
|
||||
self._transport.sendto(_error_pkt(2, "Access violation"), addr)
|
||||
else:
|
||||
_log("unknown_opcode", src=addr[0], opcode=opcode, data=data[:32].hex())
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ failed". Logs the raw response for offline cracking.
|
||||
import asyncio
|
||||
import os
|
||||
import base64 as _base64
|
||||
from typing import cast
|
||||
from syslog_bridge import syslog_line, write_syslog_file, forward_syslog
|
||||
|
||||
NODE_NAME = os.environ.get("NODE_NAME", "desktop")
|
||||
@@ -26,24 +27,29 @@ def _log(event_type: str, severity: int = 6, **kwargs) -> None:
|
||||
|
||||
|
||||
class VNCProtocol(asyncio.Protocol):
|
||||
_transport: asyncio.Transport | None = None
|
||||
_peer: tuple[str, int] | None = None
|
||||
|
||||
def __init__(self):
|
||||
self._transport = None
|
||||
self._peer = None
|
||||
self._buf = b""
|
||||
self._state = "version"
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._transport = transport
|
||||
self._peer = transport.get_extra_info("peername", ("?", 0))
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._peer = cast(tuple[str, int], self._transport.get_extra_info("peername", ("?", 0)))
|
||||
_log("connect", src=self._peer[0], src_port=self._peer[1])
|
||||
# Send RFB version
|
||||
transport.write(b"RFB 003.008\n")
|
||||
self._transport.write(b"RFB 003.008\n")
|
||||
|
||||
def data_received(self, data):
|
||||
self._buf += data
|
||||
self._process()
|
||||
|
||||
def _process(self):
|
||||
def _process(self) -> None:
|
||||
assert self._transport is not None
|
||||
assert self._peer is not None
|
||||
if self._state == "version":
|
||||
if b"\n" not in self._buf:
|
||||
return
|
||||
|
||||
@@ -178,4 +178,14 @@ Enable `check_untyped_defs = true` as the final step once the repo is clean.
|
||||
- Remove 9 stale `# type: ignore` comments across logging, helpers, credentials
|
||||
- Fix `telemetry.py` overload `no-redef` + `misc`
|
||||
- Fix `logs.py` `datetime/str` operator errors and nullable PK comparison
|
||||
- [ ] Annotate `transport` in template servers + guard call sites (P3, ~100 errors)
|
||||
- [x] P3 — template servers now have 0 mypy errors (146 fixed):
|
||||
- Add `_transport: asyncio.Transport | None` class-level annotation + `cast()` in `connection_made` for 11 TCP Protocol files (pop3, smtp, mqtt, postgres, mssql, mongodb, imap, ldap, redis, mysql, sip, vnc)
|
||||
- Add `_transport: asyncio.DatagramTransport | None` for 2 UDP DatagramProtocol files (snmp, tftp) + SIPUDPProtocol
|
||||
- `assert self._transport is not None` guards in each method that directly accesses transport
|
||||
- Fix RDP `start_tls()` `Transport | None` narrowing with `assert new_transport is not None`
|
||||
- Fix FTP Twisted stubs: `assert self.transport is not None`, `# type: ignore[misc/override/arg-type/attr-defined]` for imprecise Twisted stubs, `IReactorTCP` cast for `listenTCP`
|
||||
- Fix conpot `proc.stdout is None` guard before iteration
|
||||
- Fix SMTP `get_payload(decode=True)` → explicit `bytes` narrowing
|
||||
- Fix postgres `_handle_startup` param-name shadowing (`msg` bytes → `err_msg` str)
|
||||
- Fix mongodb `base64.binascii.Error` → `import binascii; binascii.Error`
|
||||
- Fix imap `result: list[int] = []` var-annotated
|
||||
|
||||
Reference in New Issue
Block a user