# Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. """ Test cases for L{twisted.protocols.basic}. """ from __future__ import division, absolute_import import sys import struct from io import BytesIO from zope.interface.verify import verifyObject from twisted.python.compat import _PY3, iterbytes from twisted.trial import unittest from twisted.protocols import basic from twisted.python import reflect from twisted.internet import protocol, task from twisted.internet.interfaces import IProducer from twisted.test import proto_helpers _PY3NEWSTYLESKIP = "All classes are new style on Python 3." class FlippingLineTester(basic.LineReceiver): """ A line receiver that flips between line and raw data modes after one byte. """ delimiter = b'\n' def __init__(self): self.lines = [] def lineReceived(self, line): """ Set the mode to raw. """ self.lines.append(line) self.setRawMode() def rawDataReceived(self, data): """ Set the mode back to line. """ self.setLineMode(data[1:]) class LineTester(basic.LineReceiver): """ A line receiver that parses data received and make actions on some tokens. @type delimiter: C{bytes} @ivar delimiter: character used between received lines. @type MAX_LENGTH: C{int} @ivar MAX_LENGTH: size of a line when C{lineLengthExceeded} will be called. @type clock: L{twisted.internet.task.Clock} @ivar clock: clock simulating reactor callLater. Pass it to constructor if you want to use the pause/rawpause functionalities. """ delimiter = b'\n' MAX_LENGTH = 64 def __init__(self, clock=None): """ If given, use a clock to make callLater calls. """ self.clock = clock def connectionMade(self): """ Create/clean data received on connection. """ self.received = [] def lineReceived(self, line): """ Receive line and make some action for some tokens: pause, rawpause, stop, len, produce, unproduce. """ self.received.append(line) if line == b'': self.setRawMode() elif line == b'pause': self.pauseProducing() self.clock.callLater(0, self.resumeProducing) elif line == b'rawpause': self.pauseProducing() self.setRawMode() self.received.append(b'') self.clock.callLater(0, self.resumeProducing) elif line == b'stop': self.stopProducing() elif line[:4] == b'len ': self.length = int(line[4:]) elif line.startswith(b'produce'): self.transport.registerProducer(self, False) elif line.startswith(b'unproduce'): self.transport.unregisterProducer() def rawDataReceived(self, data): """ Read raw data, until the quantity specified by a previous 'len' line is reached. """ data, rest = data[:self.length], data[self.length:] self.length = self.length - len(data) self.received[-1] = self.received[-1] + data if self.length == 0: self.setLineMode(rest) def lineLengthExceeded(self, line): """ Adjust line mode when long lines received. """ if len(line) > self.MAX_LENGTH + 1: self.setLineMode(line[self.MAX_LENGTH + 1:]) class LineOnlyTester(basic.LineOnlyReceiver): """ A buffering line only receiver. """ delimiter = b'\n' MAX_LENGTH = 64 def connectionMade(self): """ Create/clean data received on connection. """ self.received = [] def lineReceived(self, line): """ Save received data. """ self.received.append(line) class LineReceiverTests(unittest.SynchronousTestCase): """ Test L{twisted.protocols.basic.LineReceiver}, using the C{LineTester} wrapper. """ buffer = b'''\ len 10 0123456789len 5 1234 len 20 foo 123 0123456789 012345678len 0 foo 5 1234567890123456789012345678901234567890123456789012345678901234567890 len 1 a''' output = [b'len 10', b'0123456789', b'len 5', b'1234\n', b'len 20', b'foo 123', b'0123456789\n012345678', b'len 0', b'foo 5', b'', b'67890', b'len 1', b'a'] def test_buffer(self): """ Test buffering for different packet size, checking received matches expected data. """ for packet_size in range(1, 10): t = proto_helpers.StringIOWithoutClosing() a = LineTester() a.makeConnection(protocol.FileWrapper(t)) for i in range(len(self.buffer) // packet_size + 1): s = self.buffer[i * packet_size:(i + 1) * packet_size] a.dataReceived(s) self.assertEqual(self.output, a.received) pauseBuf = b'twiddle1\ntwiddle2\npause\ntwiddle3\n' pauseOutput1 = [b'twiddle1', b'twiddle2', b'pause'] pauseOutput2 = pauseOutput1 + [b'twiddle3'] def test_pausing(self): """ Test pause inside data receiving. It uses fake clock to see if pausing/resuming work. """ for packet_size in range(1, 10): t = proto_helpers.StringIOWithoutClosing() clock = task.Clock() a = LineTester(clock) a.makeConnection(protocol.FileWrapper(t)) for i in range(len(self.pauseBuf) // packet_size + 1): s = self.pauseBuf[i * packet_size:(i + 1) * packet_size] a.dataReceived(s) self.assertEqual(self.pauseOutput1, a.received) clock.advance(0) self.assertEqual(self.pauseOutput2, a.received) rawpauseBuf = b'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n' rawpauseOutput1 = [b'twiddle1', b'twiddle2', b'len 5', b'rawpause', b''] rawpauseOutput2 = [b'twiddle1', b'twiddle2', b'len 5', b'rawpause', b'12345', b'twiddle3'] def test_rawPausing(self): """ Test pause inside raw date receiving. """ for packet_size in range(1, 10): t = proto_helpers.StringIOWithoutClosing() clock = task.Clock() a = LineTester(clock) a.makeConnection(protocol.FileWrapper(t)) for i in range(len(self.rawpauseBuf) // packet_size + 1): s = self.rawpauseBuf[i * packet_size:(i + 1) * packet_size] a.dataReceived(s) self.assertEqual(self.rawpauseOutput1, a.received) clock.advance(0) self.assertEqual(self.rawpauseOutput2, a.received) stop_buf = b'twiddle1\ntwiddle2\nstop\nmore\nstuff\n' stop_output = [b'twiddle1', b'twiddle2', b'stop'] def test_stopProducing(self): """ Test stop inside producing. """ for packet_size in range(1, 10): t = proto_helpers.StringIOWithoutClosing() a = LineTester() a.makeConnection(protocol.FileWrapper(t)) for i in range(len(self.stop_buf) // packet_size + 1): s = self.stop_buf[i * packet_size:(i + 1) * packet_size] a.dataReceived(s) self.assertEqual(self.stop_output, a.received) def test_lineReceiverAsProducer(self): """ Test produce/unproduce in receiving. """ a = LineTester() t = proto_helpers.StringIOWithoutClosing() a.makeConnection(protocol.FileWrapper(t)) a.dataReceived(b'produce\nhello world\nunproduce\ngoodbye\n') self.assertEqual( a.received, [b'produce', b'hello world', b'unproduce', b'goodbye']) def test_clearLineBuffer(self): """ L{LineReceiver.clearLineBuffer} removes all buffered data and returns it as a C{bytes} and can be called from beneath C{dataReceived}. """ class ClearingReceiver(basic.LineReceiver): def lineReceived(self, line): self.line = line self.rest = self.clearLineBuffer() protocol = ClearingReceiver() protocol.dataReceived(b'foo\r\nbar\r\nbaz') self.assertEqual(protocol.line, b'foo') self.assertEqual(protocol.rest, b'bar\r\nbaz') # Deliver another line to make sure the previously buffered data is # really gone. protocol.dataReceived(b'quux\r\n') self.assertEqual(protocol.line, b'quux') self.assertEqual(protocol.rest, b'') def test_stackRecursion(self): """ Test switching modes many times on the same data. """ proto = FlippingLineTester() transport = proto_helpers.StringIOWithoutClosing() proto.makeConnection(protocol.FileWrapper(transport)) limit = sys.getrecursionlimit() proto.dataReceived(b'x\nx' * limit) self.assertEqual(b'x' * limit, b''.join(proto.lines)) def test_maximumLineLength(self): """ C{LineReceiver} disconnects the transport if it receives a line longer than its C{MAX_LENGTH}. """ proto = basic.LineReceiver() transport = proto_helpers.StringTransport() proto.makeConnection(transport) proto.dataReceived(b'x' * (proto.MAX_LENGTH + 1) + b'\r\nr') self.assertTrue(transport.disconnecting) def test_maximumLineLengthPartialDelimiter(self): """ C{LineReceiver} doesn't disconnect the transport when it receives a finished line as long as its C{MAX_LENGTH}, when the second-to-last packet ended with a pattern that could have been -- and turns out to have been -- the start of a delimiter, and that packet causes the total input to exceed C{MAX_LENGTH} + len(delimiter). """ proto = LineTester() proto.MAX_LENGTH = 4 t = proto_helpers.StringTransport() proto.makeConnection(t) line = b'x' * (proto.MAX_LENGTH - 1) proto.dataReceived(line) proto.dataReceived(proto.delimiter[:-1]) proto.dataReceived(proto.delimiter[-1:] + line) self.assertFalse(t.disconnecting) self.assertEqual(len(proto.received), 1) self.assertEqual(line, proto.received[0]) def test_notQuiteMaximumLineLengthUnfinished(self): """ C{LineReceiver} doesn't disconnect the transport it if receives a non-finished line whose length, counting the delimiter, is longer than its C{MAX_LENGTH} but shorter than its C{MAX_LENGTH} + len(delimiter). (When the first part that exceeds the max is the beginning of the delimiter.) """ proto = basic.LineReceiver() # '\r\n' is the default, but we set it just to be explicit in # this test. proto.delimiter = b'\r\n' transport = proto_helpers.StringTransport() proto.makeConnection(transport) proto.dataReceived((b'x' * proto.MAX_LENGTH) + proto.delimiter[:len(proto.delimiter)-1]) self.assertFalse(transport.disconnecting) def test_rawDataError(self): """ C{LineReceiver.dataReceived} forwards errors returned by C{rawDataReceived}. """ proto = basic.LineReceiver() proto.rawDataReceived = lambda data: RuntimeError("oops") transport = proto_helpers.StringTransport() proto.makeConnection(transport) proto.setRawMode() why = proto.dataReceived(b'data') self.assertIsInstance(why, RuntimeError) def test_rawDataReceivedNotImplemented(self): """ When L{LineReceiver.rawDataReceived} is not overridden in a subclass, calling it raises C{NotImplementedError}. """ proto = basic.LineReceiver() self.assertRaises(NotImplementedError, proto.rawDataReceived, 'foo') def test_lineReceivedNotImplemented(self): """ When L{LineReceiver.lineReceived} is not overridden in a subclass, calling it raises C{NotImplementedError}. """ proto = basic.LineReceiver() self.assertRaises(NotImplementedError, proto.lineReceived, 'foo') class ExcessivelyLargeLineCatcher(basic.LineReceiver): """ Helper for L{LineReceiverLineLengthExceededTests}. @ivar longLines: A L{list} of L{bytes} giving the values C{lineLengthExceeded} has been called with. """ def connectionMade(self): self.longLines = [] def lineReceived(self, line): """ Disregard any received lines. """ def lineLengthExceeded(self, data): """ Record any data that exceeds the line length limits. """ self.longLines.append(data) class LineReceiverLineLengthExceededTests(unittest.SynchronousTestCase): """ Tests for L{twisted.protocols.basic.LineReceiver.lineLengthExceeded}. """ def setUp(self): self.proto = ExcessivelyLargeLineCatcher() self.proto.MAX_LENGTH = 6 self.transport = proto_helpers.StringTransport() self.proto.makeConnection(self.transport) def test_longUnendedLine(self): """ If more bytes than C{LineReceiver.MAX_LENGTH} arrive containing no line delimiter, all of the bytes are passed as a single string to L{LineReceiver.lineLengthExceeded}. """ excessive = b'x' * (self.proto.MAX_LENGTH * 2 + 2) self.proto.dataReceived(excessive) self.assertEqual([excessive], self.proto.longLines) def test_longLineAfterShortLine(self): """ If L{LineReceiver.dataReceived} is called with bytes representing a short line followed by bytes that exceed the length limit without a line delimiter, L{LineReceiver.lineLengthExceeded} is called with all of the bytes following the short line's delimiter. """ excessive = b'x' * (self.proto.MAX_LENGTH * 2 + 2) self.proto.dataReceived(b'x' + self.proto.delimiter + excessive) self.assertEqual([excessive], self.proto.longLines) def test_longLineWithDelimiter(self): """ If L{LineReceiver.dataReceived} is called with more than C{LineReceiver.MAX_LENGTH} bytes containing a line delimiter somewhere not in the first C{MAX_LENGTH} bytes, the entire byte string is passed to L{LineReceiver.lineLengthExceeded}. """ excessive = self.proto.delimiter.join( [b'x' * (self.proto.MAX_LENGTH * 2 + 2)] * 2) self.proto.dataReceived(excessive) self.assertEqual([excessive], self.proto.longLines) def test_multipleLongLines(self): """ If L{LineReceiver.dataReceived} is called with more than C{LineReceiver.MAX_LENGTH} bytes containing multiple line delimiters somewhere not in the first C{MAX_LENGTH} bytes, the entire byte string is passed to L{LineReceiver.lineLengthExceeded}. """ excessive = ( b'x' * (self.proto.MAX_LENGTH * 2 + 2) + self.proto.delimiter) * 2 self.proto.dataReceived(excessive) self.assertEqual([excessive], self.proto.longLines) def test_maximumLineLength(self): """ C{LineReceiver} disconnects the transport if it receives a line longer than its C{MAX_LENGTH}. """ proto = basic.LineReceiver() transport = proto_helpers.StringTransport() proto.makeConnection(transport) proto.dataReceived(b'x' * (proto.MAX_LENGTH + 1) + b'\r\nr') self.assertTrue(transport.disconnecting) def test_maximumLineLengthRemaining(self): """ C{LineReceiver} disconnects the transport it if receives a non-finished line longer than its C{MAX_LENGTH}. """ proto = basic.LineReceiver() transport = proto_helpers.StringTransport() proto.makeConnection(transport) proto.dataReceived(b'x' * (proto.MAX_LENGTH + len(proto.delimiter))) self.assertTrue(transport.disconnecting) class LineOnlyReceiverTests(unittest.SynchronousTestCase): """ Tests for L{twisted.protocols.basic.LineOnlyReceiver}. """ buffer = b"""foo bleakness desolation plastic forks """ def test_buffer(self): """ Test buffering over line protocol: data received should match buffer. """ t = proto_helpers.StringTransport() a = LineOnlyTester() a.makeConnection(t) for c in iterbytes(self.buffer): a.dataReceived(c) self.assertEqual(a.received, self.buffer.split(b'\n')[:-1]) def test_greaterThanMaximumLineLength(self): """ C{LineOnlyReceiver} disconnects the transport if it receives a line longer than its C{MAX_LENGTH} + len(delimiter). """ proto = LineOnlyTester() transport = proto_helpers.StringTransport() proto.makeConnection(transport) proto.dataReceived(b'x' * (proto.MAX_LENGTH + len(proto.delimiter) + 1) + b'\r\nr') self.assertTrue(transport.disconnecting) def test_lineReceivedNotImplemented(self): """ When L{LineOnlyReceiver.lineReceived} is not overridden in a subclass, calling it raises C{NotImplementedError}. """ proto = basic.LineOnlyReceiver() self.assertRaises(NotImplementedError, proto.lineReceived, 'foo') class TestMixin: def connectionMade(self): self.received = [] def stringReceived(self, s): self.received.append(s) MAX_LENGTH = 50 closed = 0 def connectionLost(self, reason): self.closed = 1 class TestNetstring(TestMixin, basic.NetstringReceiver): def stringReceived(self, s): self.received.append(s) self.transport.write(s) class LPTestCaseMixin: illegalStrings = [] protocol = None def getProtocol(self): """ Return a new instance of C{self.protocol} connected to a new instance of L{proto_helpers.StringTransport}. """ t = proto_helpers.StringTransport() a = self.protocol() a.makeConnection(t) return a def test_illegal(self): """ Assert that illegal strings cause the transport to be closed. """ for s in self.illegalStrings: r = self.getProtocol() for c in iterbytes(s): r.dataReceived(c) self.assertTrue(r.transport.disconnecting) class NetstringReceiverTests(unittest.SynchronousTestCase, LPTestCaseMixin): """ Tests for L{twisted.protocols.basic.NetstringReceiver}. """ strings = [b'hello', b'world', b'how', b'are', b'you123', b':today', b"a" * 515] illegalStrings = [ b'9999999999999999999999', b'abc', b'4:abcde', b'51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',] protocol = TestNetstring def setUp(self): self.transport = proto_helpers.StringTransport() self.netstringReceiver = TestNetstring() self.netstringReceiver.makeConnection(self.transport) def test_buffer(self): """ Strings can be received in chunks of different lengths. """ for packet_size in range(1, 10): t = proto_helpers.StringTransport() a = TestNetstring() a.MAX_LENGTH = 699 a.makeConnection(t) for s in self.strings: a.sendString(s) out = t.value() for i in range(len(out) // packet_size + 1): s = out[i * packet_size:(i + 1) * packet_size] if s: a.dataReceived(s) self.assertEqual(a.received, self.strings) def test_receiveEmptyNetstring(self): """ Empty netstrings (with length '0') can be received. """ self.netstringReceiver.dataReceived(b"0:,") self.assertEqual(self.netstringReceiver.received, [b""]) def test_receiveOneCharacter(self): """ One-character netstrings can be received. """ self.netstringReceiver.dataReceived(b"1:a,") self.assertEqual(self.netstringReceiver.received, [b"a"]) def test_receiveTwoCharacters(self): """ Two-character netstrings can be received. """ self.netstringReceiver.dataReceived(b"2:ab,") self.assertEqual(self.netstringReceiver.received, [b"ab"]) def test_receiveNestedNetstring(self): """ Netstrings with embedded netstrings. This test makes sure that the parser does not become confused about the ',' and ':' characters appearing inside the data portion of the netstring. """ self.netstringReceiver.dataReceived(b"4:1:a,,") self.assertEqual(self.netstringReceiver.received, [b"1:a,"]) def test_moreDataThanSpecified(self): """ Netstrings containing more data than expected are refused. """ self.netstringReceiver.dataReceived(b"2:aaa,") self.assertTrue(self.transport.disconnecting) def test_moreDataThanSpecifiedBorderCase(self): """ Netstrings that should be empty according to their length specification are refused if they contain data. """ self.netstringReceiver.dataReceived(b"0:a,") self.assertTrue(self.transport.disconnecting) def test_missingNumber(self): """ Netstrings without leading digits that specify the length are refused. """ self.netstringReceiver.dataReceived(b":aaa,") self.assertTrue(self.transport.disconnecting) def test_missingColon(self): """ Netstrings without a colon between length specification and data are refused. """ self.netstringReceiver.dataReceived(b"3aaa,") self.assertTrue(self.transport.disconnecting) def test_missingNumberAndColon(self): """ Netstrings that have no leading digits nor a colon are refused. """ self.netstringReceiver.dataReceived(b"aaa,") self.assertTrue(self.transport.disconnecting) def test_onlyData(self): """ Netstrings consisting only of data are refused. """ self.netstringReceiver.dataReceived(b"aaa") self.assertTrue(self.transport.disconnecting) def test_receiveNetstringPortions_1(self): """ Netstrings can be received in two portions. """ self.netstringReceiver.dataReceived(b"4:aa") self.netstringReceiver.dataReceived(b"aa,") self.assertEqual(self.netstringReceiver.received, [b"aaaa"]) self.assertTrue(self.netstringReceiver._payloadComplete()) def test_receiveNetstringPortions_2(self): """ Netstrings can be received in more than two portions, even if the length specification is split across two portions. """ for part in [b"1", b"0:01234", b"56789", b","]: self.netstringReceiver.dataReceived(part) self.assertEqual(self.netstringReceiver.received, [b"0123456789"]) def test_receiveNetstringPortions_3(self): """ Netstrings can be received one character at a time. """ for part in [b"2", b":", b"a", b"b", b","]: self.netstringReceiver.dataReceived(part) self.assertEqual(self.netstringReceiver.received, [b"ab"]) def test_receiveTwoNetstrings(self): """ A stream of two netstrings can be received in two portions, where the first portion contains the complete first netstring and the length specification of the second netstring. """ self.netstringReceiver.dataReceived(b"1:a,1") self.assertTrue(self.netstringReceiver._payloadComplete()) self.assertEqual(self.netstringReceiver.received, [b"a"]) self.netstringReceiver.dataReceived(b":b,") self.assertEqual(self.netstringReceiver.received, [b"a", b"b"]) def test_maxReceiveLimit(self): """ Netstrings with a length specification exceeding the specified C{MAX_LENGTH} are refused. """ tooLong = self.netstringReceiver.MAX_LENGTH + 1 self.netstringReceiver.dataReceived(b"".join( (bytes(tooLong), b":", b"a" * tooLong))) self.assertTrue(self.transport.disconnecting) def test_consumeLength(self): """ C{_consumeLength} returns the expected length of the netstring, including the trailing comma. """ self.netstringReceiver._remainingData = b"12:" self.netstringReceiver._consumeLength() self.assertEqual(self.netstringReceiver._expectedPayloadSize, 13) def test_consumeLengthBorderCase1(self): """ C{_consumeLength} works as expected if the length specification contains the value of C{MAX_LENGTH} (border case). """ self.netstringReceiver._remainingData = b"12:" self.netstringReceiver.MAX_LENGTH = 12 self.netstringReceiver._consumeLength() self.assertEqual(self.netstringReceiver._expectedPayloadSize, 13) def test_consumeLengthBorderCase2(self): """ C{_consumeLength} raises a L{basic.NetstringParseError} if the length specification exceeds the value of C{MAX_LENGTH} by 1 (border case). """ self.netstringReceiver._remainingData = b"12:" self.netstringReceiver.MAX_LENGTH = 11 self.assertRaises(basic.NetstringParseError, self.netstringReceiver._consumeLength) def test_consumeLengthBorderCase3(self): """ C{_consumeLength} raises a L{basic.NetstringParseError} if the length specification exceeds the value of C{MAX_LENGTH} by more than 1. """ self.netstringReceiver._remainingData = b"1000:" self.netstringReceiver.MAX_LENGTH = 11 self.assertRaises(basic.NetstringParseError, self.netstringReceiver._consumeLength) def test_stringReceivedNotImplemented(self): """ When L{NetstringReceiver.stringReceived} is not overridden in a subclass, calling it raises C{NotImplementedError}. """ proto = basic.NetstringReceiver() self.assertRaises(NotImplementedError, proto.stringReceived, 'foo') class IntNTestCaseMixin(LPTestCaseMixin): """ TestCase mixin for int-prefixed protocols. """ protocol = None strings = None illegalStrings = None partialStrings = None def test_receive(self): """ Test receiving data find the same data send. """ r = self.getProtocol() for s in self.strings: for c in iterbytes(struct.pack(r.structFormat,len(s)) + s): r.dataReceived(c) self.assertEqual(r.received, self.strings) def test_partial(self): """ Send partial data, nothing should be definitely received. """ for s in self.partialStrings: r = self.getProtocol() for c in iterbytes(s): r.dataReceived(c) self.assertEqual(r.received, []) def test_send(self): """ Test sending data over protocol. """ r = self.getProtocol() r.sendString(b"b" * 16) self.assertEqual(r.transport.value(), struct.pack(r.structFormat, 16) + b"b" * 16) def test_lengthLimitExceeded(self): """ When a length prefix is received which is greater than the protocol's C{MAX_LENGTH} attribute, the C{lengthLimitExceeded} method is called with the received length prefix. """ length = [] r = self.getProtocol() r.lengthLimitExceeded = length.append r.MAX_LENGTH = 10 r.dataReceived(struct.pack(r.structFormat, 11)) self.assertEqual(length, [11]) def test_longStringNotDelivered(self): """ If a length prefix for a string longer than C{MAX_LENGTH} is delivered to C{dataReceived} at the same time as the entire string, the string is not passed to C{stringReceived}. """ r = self.getProtocol() r.MAX_LENGTH = 10 r.dataReceived( struct.pack(r.structFormat, 11) + b'x' * 11) self.assertEqual(r.received, []) def test_stringReceivedNotImplemented(self): """ When L{IntNStringReceiver.stringReceived} is not overridden in a subclass, calling it raises C{NotImplementedError}. """ proto = basic.IntNStringReceiver() self.assertRaises(NotImplementedError, proto.stringReceived, 'foo') class RecvdAttributeMixin(object): """ Mixin defining tests for string receiving protocols with a C{recvd} attribute which should be settable by application code, to be combined with L{IntNTestCaseMixin} on a L{TestCase} subclass """ def makeMessage(self, protocol, data): """ Return C{data} prefixed with message length in C{protocol.structFormat} form. """ return struct.pack(protocol.structFormat, len(data)) + data def test_recvdContainsRemainingData(self): """ In stringReceived, recvd contains the remaining data that was passed to dataReceived that was not part of the current message. """ result = [] r = self.getProtocol() def stringReceived(receivedString): result.append(r.recvd) r.stringReceived = stringReceived completeMessage = (struct.pack(r.structFormat, 5) + (b'a' * 5)) incompleteMessage = (struct.pack(r.structFormat, 5) + (b'b' * 4)) # Receive a complete message, followed by an incomplete one r.dataReceived(completeMessage + incompleteMessage) self.assertEqual(result, [incompleteMessage]) def test_recvdChanged(self): """ In stringReceived, if recvd is changed, messages should be parsed from it rather than the input to dataReceived. """ r = self.getProtocol() result = [] payloadC = b'c' * 5 messageC = self.makeMessage(r, payloadC) def stringReceived(receivedString): if not result: r.recvd = messageC result.append(receivedString) r.stringReceived = stringReceived payloadA = b'a' * 5 payloadB = b'b' * 5 messageA = self.makeMessage(r, payloadA) messageB = self.makeMessage(r, payloadB) r.dataReceived(messageA + messageB) self.assertEqual(result, [payloadA, payloadC]) def test_switching(self): """ Data already parsed by L{IntNStringReceiver.dataReceived} is not reparsed if C{stringReceived} consumes some of the L{IntNStringReceiver.recvd} buffer. """ proto = self.getProtocol() mix = [] SWITCH = b"\x00\x00\x00\x00" for s in self.strings: mix.append(self.makeMessage(proto, s)) mix.append(SWITCH) result = [] def stringReceived(receivedString): result.append(receivedString) proto.recvd = proto.recvd[len(SWITCH):] proto.stringReceived = stringReceived proto.dataReceived(b"".join(mix)) # Just another byte, to trigger processing of anything that might have # been left in the buffer (should be nothing). proto.dataReceived(b"\x01") self.assertEqual(result, self.strings) # And verify that another way self.assertEqual(proto.recvd, b"\x01") def test_recvdInLengthLimitExceeded(self): """ The L{IntNStringReceiver.recvd} buffer contains all data not yet processed by L{IntNStringReceiver.dataReceived} if the C{lengthLimitExceeded} event occurs. """ proto = self.getProtocol() DATA = b"too long" proto.MAX_LENGTH = len(DATA) - 1 message = self.makeMessage(proto, DATA) result = [] def lengthLimitExceeded(length): result.append(length) result.append(proto.recvd) proto.lengthLimitExceeded = lengthLimitExceeded proto.dataReceived(message) self.assertEqual(result[0], len(DATA)) self.assertEqual(result[1], message) class TestInt32(TestMixin, basic.Int32StringReceiver): """ A L{basic.Int32StringReceiver} storing received strings in an array. @ivar received: array holding received strings. """ class Int32Tests(unittest.SynchronousTestCase, IntNTestCaseMixin, RecvdAttributeMixin): """ Test case for int32-prefixed protocol """ protocol = TestInt32 strings = [b"a", b"b" * 16] illegalStrings = [b"\x10\x00\x00\x00aaaaaa"] partialStrings = [b"\x00\x00\x00", b"hello there", b""] def test_data(self): """ Test specific behavior of the 32-bits length. """ r = self.getProtocol() r.sendString(b"foo") self.assertEqual(r.transport.value(), b"\x00\x00\x00\x03foo") r.dataReceived(b"\x00\x00\x00\x04ubar") self.assertEqual(r.received, [b"ubar"]) class TestInt16(TestMixin, basic.Int16StringReceiver): """ A L{basic.Int16StringReceiver} storing received strings in an array. @ivar received: array holding received strings. """ class Int16Tests(unittest.SynchronousTestCase, IntNTestCaseMixin, RecvdAttributeMixin): """ Test case for int16-prefixed protocol """ protocol = TestInt16 strings = [b"a", b"b" * 16] illegalStrings = [b"\x10\x00aaaaaa"] partialStrings = [b"\x00", b"hello there", b""] def test_data(self): """ Test specific behavior of the 16-bits length. """ r = self.getProtocol() r.sendString(b"foo") self.assertEqual(r.transport.value(), b"\x00\x03foo") r.dataReceived(b"\x00\x04ubar") self.assertEqual(r.received, [b"ubar"]) def test_tooLongSend(self): """ Send too much data: that should cause an error. """ r = self.getProtocol() tooSend = b"b" * (2**(r.prefixLength * 8) + 1) self.assertRaises(AssertionError, r.sendString, tooSend) class NewStyleTestInt16(TestInt16, object): """ A new-style class version of TestInt16 """ class NewStyleInt16Tests(Int16Tests): """ This test case verifies that IntNStringReceiver still works when inherited by a new-style class. """ if _PY3: skip = _PY3NEWSTYLESKIP protocol = NewStyleTestInt16 class TestInt8(TestMixin, basic.Int8StringReceiver): """ A L{basic.Int8StringReceiver} storing received strings in an array. @ivar received: array holding received strings. """ class Int8Tests(unittest.SynchronousTestCase, IntNTestCaseMixin, RecvdAttributeMixin): """ Test case for int8-prefixed protocol """ protocol = TestInt8 strings = [b"a", b"b" * 16] illegalStrings = [b"\x00\x00aaaaaa"] partialStrings = [b"\x08", b"dzadz", b""] def test_data(self): """ Test specific behavior of the 8-bits length. """ r = self.getProtocol() r.sendString(b"foo") self.assertEqual(r.transport.value(), b"\x03foo") r.dataReceived(b"\x04ubar") self.assertEqual(r.received, [b"ubar"]) def test_tooLongSend(self): """ Send too much data: that should cause an error. """ r = self.getProtocol() tooSend = b"b" * (2**(r.prefixLength * 8) + 1) self.assertRaises(AssertionError, r.sendString, tooSend) class OnlyProducerTransport(object): """ Transport which isn't really a transport, just looks like one to someone not looking very hard. """ paused = False disconnecting = False def __init__(self): self.data = [] def pauseProducing(self): self.paused = True def resumeProducing(self): self.paused = False def write(self, bytes): self.data.append(bytes) class ConsumingProtocol(basic.LineReceiver): """ Protocol that really, really doesn't want any more bytes. """ def lineReceived(self, line): self.transport.write(line) self.pauseProducing() class ProducerTests(unittest.SynchronousTestCase): """ Tests for L{basic._PausableMixin} and L{basic.LineReceiver.paused}. """ def test_pauseResume(self): """ When L{basic.LineReceiver} is paused, it doesn't deliver lines to L{basic.LineReceiver.lineReceived} and delivers them immediately upon being resumed. L{ConsumingProtocol} is a L{LineReceiver} that pauses itself after every line, and writes that line to its transport. """ p = ConsumingProtocol() t = OnlyProducerTransport() p.makeConnection(t) # Deliver a partial line. # This doesn't trigger a pause and doesn't deliver a line. p.dataReceived(b'hello, ') self.assertEqual(t.data, []) self.assertFalse(t.paused) self.assertFalse(p.paused) # Deliver the rest of the line. # This triggers the pause, and the line is echoed. p.dataReceived(b'world\r\n') self.assertEqual(t.data, [b'hello, world']) self.assertTrue(t.paused) self.assertTrue(p.paused) # Unpausing doesn't deliver more data, and the protocol is unpaused. p.resumeProducing() self.assertEqual(t.data, [b'hello, world']) self.assertFalse(t.paused) self.assertFalse(p.paused) # Deliver two lines at once. # The protocol is paused after receiving and echoing the first line. p.dataReceived(b'hello\r\nworld\r\n') self.assertEqual(t.data, [b'hello, world', b'hello']) self.assertTrue(t.paused) self.assertTrue(p.paused) # Unpausing delivers the waiting line, and causes the protocol to # pause again. p.resumeProducing() self.assertEqual(t.data, [b'hello, world', b'hello', b'world']) self.assertTrue(t.paused) self.assertTrue(p.paused) # Deliver a line while paused. # This doesn't have a visible effect. p.dataReceived(b'goodbye\r\n') self.assertEqual(t.data, [b'hello, world', b'hello', b'world']) self.assertTrue(t.paused) self.assertTrue(p.paused) # Unpausing delivers the waiting line, and causes the protocol to # pause again. p.resumeProducing() self.assertEqual( t.data, [b'hello, world', b'hello', b'world', b'goodbye']) self.assertTrue(t.paused) self.assertTrue(p.paused) # Unpausing doesn't deliver more data, and the protocol is unpaused. p.resumeProducing() self.assertEqual( t.data, [b'hello, world', b'hello', b'world', b'goodbye']) self.assertFalse(t.paused) self.assertFalse(p.paused) class FileSenderTests(unittest.TestCase): """ Tests for L{basic.FileSender}. """ def test_interface(self): """ L{basic.FileSender} implements the L{IPullProducer} interface. """ sender = basic.FileSender() self.assertTrue(verifyObject(IProducer, sender)) def test_producerRegistered(self): """ When L{basic.FileSender.beginFileTransfer} is called, it registers itself with provided consumer, as a non-streaming producer. """ source = BytesIO(b"Test content") consumer = proto_helpers.StringTransport() sender = basic.FileSender() sender.beginFileTransfer(source, consumer) self.assertEqual(consumer.producer, sender) self.assertFalse(consumer.streaming) def test_transfer(self): """ L{basic.FileSender} sends the content of the given file using a C{IConsumer} interface via C{beginFileTransfer}. It returns a L{Deferred} which fires with the last byte sent. """ source = BytesIO(b"Test content") consumer = proto_helpers.StringTransport() sender = basic.FileSender() d = sender.beginFileTransfer(source, consumer) sender.resumeProducing() # resumeProducing only finishes after trying to read at eof sender.resumeProducing() self.assertIsNone(consumer.producer) self.assertEqual(b"t", self.successResultOf(d)) self.assertEqual(b"Test content", consumer.value()) def test_transferMultipleChunks(self): """ L{basic.FileSender} reads at most C{CHUNK_SIZE} every time it resumes producing. """ source = BytesIO(b"Test content") consumer = proto_helpers.StringTransport() sender = basic.FileSender() sender.CHUNK_SIZE = 4 d = sender.beginFileTransfer(source, consumer) # Ideally we would assertNoResult(d) here, but sender.resumeProducing() self.assertEqual(b"Test", consumer.value()) sender.resumeProducing() self.assertEqual(b"Test con", consumer.value()) sender.resumeProducing() self.assertEqual(b"Test content", consumer.value()) # resumeProducing only finishes after trying to read at eof sender.resumeProducing() self.assertEqual(b"t", self.successResultOf(d)) self.assertEqual(b"Test content", consumer.value()) def test_transferWithTransform(self): """ L{basic.FileSender.beginFileTransfer} takes a C{transform} argument which allows to manipulate the data on the fly. """ def transform(chunk): return chunk.swapcase() source = BytesIO(b"Test content") consumer = proto_helpers.StringTransport() sender = basic.FileSender() d = sender.beginFileTransfer(source, consumer, transform) sender.resumeProducing() # resumeProducing only finishes after trying to read at eof sender.resumeProducing() self.assertEqual(b"T", self.successResultOf(d)) self.assertEqual(b"tEST CONTENT", consumer.value()) def test_abortedTransfer(self): """ The C{Deferred} returned by L{basic.FileSender.beginFileTransfer} fails with an C{Exception} if C{stopProducing} when the transfer is not complete. """ source = BytesIO(b"Test content") consumer = proto_helpers.StringTransport() sender = basic.FileSender() d = sender.beginFileTransfer(source, consumer) # Abort the transfer right away sender.stopProducing() failure = self.failureResultOf(d) failure.trap(Exception) self.assertEqual("Consumer asked us to stop producing", str(failure.value)) class MiceDeprecationTests(unittest.TestCase): """ L{twisted.protocols.mice} is deprecated. """ if _PY3: skip = "twisted.protocols.mice is not being ported to Python 3." def test_MiceDeprecation(self): """ L{twisted.protocols.mice} is deprecated since Twisted 16.0. """ reflect.namedAny("twisted.protocols.mice") warningsShown = self.flushWarnings() self.assertEqual(1, len(warningsShown)) self.assertEqual( "twisted.protocols.mice was deprecated in Twisted 16.0.0: " "There is no replacement for this module.", warningsShown[0]['message'])