[Neo-report] r2295 vincent - in /trunk: neo/ neo/storage/ neo/storage/database/ neo/storag...
nobody at svn.erp5.org
Sun Sep 5 11:15:04 CEST 2010
Author: vincent
Date: Sun Sep 5 11:15:04 2010
New Revision: 2295
Log:
Implement rsync-ish replication.
For further description, see storage/handlers/replication.py .
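The approach is checksum-driven rather than copy-everything. As a rough,
standalone sketch of the idea (plain integer sets and direct function calls
standing in for the real packet exchange between storage nodes):

RANGE_LENGTH = 4000
MIN_RANGE_LENGTH = 1000

def checksum_range(rows, start, length):
    # (count, xor checksum, highest id) for up to length rows >= start
    chunk = sorted(r for r in rows if r >= start)[:length]
    checksum = 0
    for r in chunk:
        checksum ^= r
    return len(chunk), checksum, (chunk[-1] if chunk else 0)

def replicate(reference, replica, start=0, length=RANGE_LENGTH):
    while True:
        ours = checksum_range(replica, start, length)
        theirs = checksum_range(reference, start, length)
        count, _, max_id = theirs
        if ours != theirs:
            if length > MIN_RANGE_LENGTH:
                # Mismatch on a big chunk: re-check a smaller window first.
                length = max(MIN_RANGE_LENGTH, length // 2)
                continue
            # Minimal window still differs: make it match the reference
            # (delete extra local rows, copy the missing ones).
            wanted = set(sorted(r for r in reference if r >= start)[:length])
            extra = set(r for r in replica if r >= start and
                        (count < length or r <= max_id))
            replica -= extra - wanted
            replica |= wanted
        if count < length:
            break   # the reference returned fewer rows than asked: done
        start, length = max_id + 1, RANGE_LENGTH

reference = set(range(100))
replica = set(range(0, 100, 3))
replicate(reference, replica)
assert replica == reference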
Added:
trunk/neo/tests/storage/testReplicationHandler.py
trunk/neo/tests/storage/testReplicator.py
Modified:
trunk/neo/handler.py
trunk/neo/protocol.py
trunk/neo/storage/app.py
trunk/neo/storage/database/manager.py
trunk/neo/storage/database/mysqldb.py
trunk/neo/storage/handlers/replication.py
trunk/neo/storage/handlers/storage.py
trunk/neo/storage/replicator.py
trunk/neo/tests/storage/testStorageHandler.py
trunk/neo/tests/storage/testStorageMySQLdb.py
trunk/neo/tests/testProtocol.py
trunk/tools/runner
Modified: trunk/neo/handler.py
==============================================================================
--- trunk/neo/handler.py [iso-8859-1] (original)
+++ trunk/neo/handler.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -275,16 +275,10 @@ class EventHandler(object):
def answerObjectHistory(self, conn, oid, history_list):
raise UnexpectedPacketError
- def askObjectHistoryFrom(self, conn, oid, min_serial, length):
+ def askObjectHistoryFrom(self, conn, oid, min_serial, length, partition):
raise UnexpectedPacketError
- def answerObjectHistoryFrom(self, conn, oid, history_list):
- raise UnexpectedPacketError
-
- def askOIDs(self, conn, min_oid, length, partition):
- raise UnexpectedPacketError
-
- def answerOIDs(self, conn, oid_list):
+ def answerObjectHistoryFrom(self, conn, object_dict):
raise UnexpectedPacketError
def askPartitionList(self, conn, min_offset, max_offset, uuid):
@@ -358,6 +352,21 @@ class EventHandler(object):
def answerPack(self, conn, status):
raise UnexpectedPacketError
+
+ def askCheckTIDRange(self, conn, min_tid, length, partition):
+ raise UnexpectedPacketError
+
+ def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
+ max_tid):
+ raise UnexpectedPacketError
+
+ def askCheckSerialRange(self, conn, min_oid, min_serial, length,
+ partition):
+ raise UnexpectedPacketError
+
+ def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
+ oid_checksum, max_oid, serial_checksum, max_serial):
+ raise UnexpectedPacketError
# Error packet handlers.
@@ -450,8 +459,6 @@ class EventHandler(object):
d[Packets.AnswerObjectHistory] = self.answerObjectHistory
d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom
d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom
- d[Packets.AskOIDs] = self.askOIDs
- d[Packets.AnswerOIDs] = self.answerOIDs
d[Packets.AskPartitionList] = self.askPartitionList
d[Packets.AnswerPartitionList] = self.answerPartitionList
d[Packets.AskNodeList] = self.askNodeList
@@ -476,6 +483,10 @@ class EventHandler(object):
d[Packets.AnswerBarrier] = self.answerBarrier
d[Packets.AskPack] = self.askPack
d[Packets.AnswerPack] = self.answerPack
+ d[Packets.AskCheckTIDRange] = self.askCheckTIDRange
+ d[Packets.AnswerCheckTIDRange] = self.answerCheckTIDRange
+ d[Packets.AskCheckSerialRange] = self.askCheckSerialRange
+ d[Packets.AnswerCheckSerialRange] = self.answerCheckSerialRange
return d
Modified: trunk/neo/protocol.py
==============================================================================
--- trunk/neo/protocol.py [iso-8859-1] (original)
+++ trunk/neo/protocol.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -113,6 +113,7 @@ INVALID_PARTITION = 0xffffffff
ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID)
+TID_LEN = len(INVALID_TID)
UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S',
@@ -1167,63 +1168,47 @@ class AnswerObjectHistory(Packet):
class AskObjectHistoryFrom(Packet):
"""
Ask history information for a given object. The order of serials is
- ascending, and starts at (or above) min_serial. S -> S.
+ ascending, and starts at (or above) min_serial for min_oid. S -> S.
"""
- _header_format = '!8s8sL'
+ _header_format = '!8s8sLL'
- def _encode(self, oid, min_serial, length):
- return pack(self._header_format, oid, min_serial, length)
+ def _encode(self, min_oid, min_serial, length, partition):
+ return pack(self._header_format, min_oid, min_serial, length,
+ partition)
def _decode(self, body):
- return unpack(self._header_format, body) # oid, min_serial, length
+ # min_oid, min_serial, length, partition
+ return unpack(self._header_format, body)
-class AnswerObjectHistoryFrom(AskFinishTransaction):
+class AnswerObjectHistoryFrom(Packet):
"""
Answer the requested serials. S -> S.
"""
- # This is similar to AskFinishTransaction as TID size is identical to OID
- # size:
- # - we have a single OID (TID in AskFinishTransaction)
- # - we have a list of TIDs (OIDs in AskFinishTransaction)
- pass
-
-class AskOIDs(Packet):
- """
- Ask for length OIDs starting at min_oid. S -> S.
- """
- _header_format = '!8sLL'
-
- def _encode(self, min_oid, length, partition):
- return pack(self._header_format, min_oid, length, partition)
-
- def _decode(self, body):
- return unpack(self._header_format, body) # min_oid, length, partition
-
-class AnswerOIDs(Packet):
- """
- Answer the requested OIDs. S -> S.
- """
_header_format = '!L'
- _list_entry_format = '8s'
+ _list_entry_format = '!8sL'
_list_entry_len = calcsize(_list_entry_format)
- def _encode(self, oid_list):
- body = [pack(self._header_format, len(oid_list))]
- body.extend(oid_list)
+ def _encode(self, object_dict):
+ body = [pack(self._header_format, len(object_dict))]
+ append = body.append
+ extend = body.extend
+ list_entry_format = self._list_entry_format
+ for oid, serial_list in object_dict.iteritems():
+ append(pack(list_entry_format, oid, len(serial_list)))
+ extend(serial_list)
return ''.join(body)
def _decode(self, body):
- offset = self._header_len
- (n,) = unpack(self._header_format, body[:offset])
- oid_list = []
+ body = StringIO(body)
+ read = body.read
list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len
- for _ in xrange(n):
- next_offset = offset + list_entry_len
- oid = unpack(list_entry_format, body[offset:next_offset])[0]
- offset = next_offset
- oid_list.append(oid)
- return (oid_list,)
+ object_dict = {}
+ dict_len = unpack(self._header_format, read(self._header_len))[0]
+ for _ in xrange(dict_len):
+ oid, serial_len = unpack(list_entry_format, read(list_entry_len))
+ object_dict[oid] = [read(TID_LEN) for _ in xrange(serial_len)]
+ return (object_dict, )
class AskPartitionList(Packet):
"""
@@ -1660,6 +1645,73 @@ class AnswerPack(Packet):
def _decode(self, body):
return (bool(unpack(self._header_format, body)[0]), )
+class AskCheckTIDRange(Packet):
+ """
+ Ask some stats about a range of transactions.
+ Used to know if there are differences between a replicating node and
+ reference node.
+ S -> S
+ """
+ _header_format = '!8sLL'
+
+ def _encode(self, min_tid, length, partition):
+ return pack(self._header_format, min_tid, length, partition)
+
+ def _decode(self, body):
+ return unpack(self._header_format, body) # min_tid, length, partition
+
+class AnswerCheckTIDRange(Packet):
+ """
+ Stats about a range of transactions.
+ Used to know if there are differences between a replicating node and
+ reference node.
+ S -> S
+ """
+ _header_format = '!8sLLQ8s'
+ def _encode(self, min_tid, length, count, tid_checksum, max_tid):
+ return pack(self._header_format, min_tid, length, count, tid_checksum,
+ max_tid)
+
+ def _decode(self, body):
+ # min_tid, length, count, tid_checksum, max_tid
+ return unpack(self._header_format, body)
+
+class AskCheckSerialRange(Packet):
+ """
+ Ask some stats about a range of object history.
+ Used to know if there are differences between a replicating node and
+ reference node.
+ S -> S
+ """
+ _header_format = '!8s8sLL'
+
+ def _encode(self, min_oid, min_serial, length, partition):
+ return pack(self._header_format, min_oid, min_serial, length,
+ partition)
+
+ def _decode(self, body):
+ # min_oid, min_serial, length, partition
+ return unpack(self._header_format, body)
+
+class AnswerCheckSerialRange(Packet):
+ """
+ Stats about a range of object history.
+ Used to know if there are differences between a replicating node and
+ reference node.
+ S -> S
+ """
+ _header_format = '!8s8sLLQ8sQ8s'
+
+ def _encode(self, min_oid, min_serial, length, count, oid_checksum,
+ max_oid, serial_checksum, max_serial):
+ return pack(self._header_format, min_oid, min_serial, length, count,
+ oid_checksum, max_oid, serial_checksum, max_serial)
+
+ def _decode(self, body):
+ # min_oid, min_serial, length, count, oid_checksum, max_oid,
+ # serial_checksum, max_serial
+ return unpack(self._header_format, body)
+
class Error(Packet):
"""
Error is a special type of message, because this can be sent against
@@ -1844,10 +1896,6 @@ class PacketRegistry(dict):
0x001F,
AskObjectHistory,
AnswerObjectHistory)
- AskOIDs, AnswerOIDs = register(
- 0x0020,
- AskOIDs,
- AnswerOIDs)
AskPartitionList, AnswerPartitionList = register(
0x0021,
AskPartitionList,
@@ -1903,6 +1951,16 @@ class PacketRegistry(dict):
0x0038,
AskPack,
AnswerPack)
+ AskCheckTIDRange, AnswerCheckTIDRange = register(
+ 0x0039,
+ AskCheckTIDRange,
+ AnswerCheckTIDRange,
+ )
+ AskCheckSerialRange, AnswerCheckSerialRange = register(
+ 0x003A,
+ AskCheckSerialRange,
+ AnswerCheckSerialRange,
+ )
# build a "singleton"
Packets = PacketRegistry()
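For reference, the new AnswerObjectHistoryFrom body is a 4-byte oid count
followed, for each oid, by the 8-byte oid, a 4-byte serial count and that many
8-byte serials. A standalone round-trip sketch in the same Python 2 style as
the code above (an illustration only, not the neo.protocol class itself):

from struct import pack, unpack, calcsize
from StringIO import StringIO

TID_LEN = 8
header_format = '!L'
entry_format = '!8sL'
entry_len = calcsize(entry_format)

def encode(object_dict):
    body = [pack(header_format, len(object_dict))]
    for oid, serial_list in object_dict.iteritems():
        body.append(pack(entry_format, oid, len(serial_list)))
        body.extend(serial_list)
    return ''.join(body)

def decode(body):
    read = StringIO(body).read
    (dict_len, ) = unpack(header_format, read(calcsize(header_format)))
    object_dict = {}
    for _ in xrange(dict_len):
        oid, serial_len = unpack(entry_format, read(entry_len))
        object_dict[oid] = [read(TID_LEN) for _ in xrange(serial_len)]
    return object_dict

sample = {'\x00' * 7 + '\x01': ['\x00' * 7 + '\x05', '\x00' * 7 + '\x09']}
assert decode(encode(sample)) == sample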
Modified: trunk/neo/storage/app.py
==============================================================================
--- trunk/neo/storage/app.py [iso-8859-1] (original)
+++ trunk/neo/storage/app.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -288,6 +288,12 @@ class Application(object):
while True:
em.poll(1)
if self.replicator.pending():
+ # Call processDelayedTasks before act: tasks added during the act
+ # call are then executed only after one more poll call, so the
+ # packets act sends are already on the network and the delayed local
+ # processing runs in parallel with the peer working on the same
+ # request.
+ self.replicator.processDelayedTasks()
self.replicator.act()
def wait(self):
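The comment above is purely about ordering: a task queued while act() runs is
only executed on the next loop iteration, after poll() has flushed the packets
act() just sent. A toy sketch of that loop, with a hypothetical replicator
stub standing in for the real one:

class ReplicatorStub(object):
    def __init__(self):
        self.task_list = []
    def pending(self):
        return True
    def act(self):
        # ask the peer something and queue the matching local work
        self.task_list.append(lambda: 'local checksum')
    def processDelayedTasks(self):
        for task in self.task_list:
            task()
        self.task_list = []

replicator = ReplicatorStub()
for _ in range(2):
    # em.poll(1) would run here, flushing the packets queued by the
    # previous act() call before the delayed tasks are processed
    if replicator.pending():
        replicator.processDelayedTasks()
        replicator.act()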
Modified: trunk/neo/storage/database/manager.py
==============================================================================
--- trunk/neo/storage/database/manager.py [iso-8859-1] (original)
+++ trunk/neo/storage/database/manager.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -274,6 +274,11 @@ class DatabaseManager(object):
area."""
raise NotImplementedError
+ def deleteObject(self, oid, serial=None):
+ """Delete given object. If serial is given, only delete that serial for
+ given oid."""
+ raise NotImplementedError
+
def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information,
a description, and extension information, for a given transaction
@@ -282,12 +287,6 @@ class DatabaseManager(object):
area as well."""
raise NotImplementedError
- def getOIDList(self, min_oid, length, num_partitions, partition_list):
- """Return a list of OIDs in ascending order from a minimal oid,
- at most the specified length. The list of partitions are passed
- to filter out non-applicable TIDs."""
- raise NotImplementedError
-
def getObjectHistory(self, oid, offset = 0, length = 1):
"""Return a list of serials and sizes for a given object ID.
The length specifies the maximum size of such a list. Result starts
@@ -295,9 +294,11 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None."""
raise NotImplementedError
- def getObjectHistoryFrom(self, oid, min_serial, length):
- """Return a list of length serials for a given object ID at (or above)
- min_serial, sorted in ascending order."""
+ def getObjectHistoryFrom(self, min_oid, min_serial, length, num_partitions,
+ partition):
+ """Return a dict mapping each oid to a list of serials, starting at (or
+ above) min_oid and min_serial for the given partition, sorted in
+ ascending order and containing at most length entries in total."""
raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list):
@@ -307,20 +308,10 @@ class DatabaseManager(object):
raise NotImplementedError
def getReplicationTIDList(self, min_tid, length, num_partitions,
- partition_list):
+ partition):
"""Return a list of TIDs in ascending order from an initial tid value,
- at most the specified length. The list of partitions are passed
- to filter out non-applicable TIDs."""
- raise NotImplementedError
-
- def getTIDListPresent(self, tid_list):
- """Return a list of TIDs which are present in a database among
- the given list."""
- raise NotImplementedError
-
- def getSerialListPresent(self, oid, serial_list):
- """Return a list of serials which are present in a database among
- the given list."""
+ at most the specified length. The partition number is passed to filter
+ out non-applicable TIDs."""
raise NotImplementedError
def pack(self, tid, updateObjectDataForPack):
Modified: trunk/neo/storage/database/mysqldb.py
==============================================================================
--- trunk/neo/storage/database/mysqldb.py [iso-8859-1] (original)
+++ trunk/neo/storage/database/mysqldb.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -24,7 +24,7 @@ import string
from neo.storage.database import DatabaseManager
from neo.exception import DatabaseFailure
-from neo.protocol import CellStates
+from neo.protocol import CellStates, ZERO_OID, ZERO_TID
from neo import util
LOG_QUERIES = False
@@ -576,6 +576,23 @@ class MySQLDatabaseManager(DatabaseManag
raise
self.commit()
+ def deleteObject(self, oid, serial=None):
+ u64 = util.u64
+ query_param_dict = {
+ 'oid': u64(oid),
+ }
+ query_fmt = 'DELETE FROM obj WHERE oid = %(oid)d'
+ if serial is not None:
+ query_param_dict['serial'] = u64(serial)
+ query_fmt = query_fmt + ' AND serial = %(serial)d'
+ self.begin()
+ try:
+ self.query(query_fmt % query_param_dict)
+ except:
+ self.rollback()
+ raise
+ self.commit()
+
def getTransaction(self, tid, all = False):
q = self.query
tid = util.u64(tid)
@@ -594,20 +611,6 @@ class MySQLDatabaseManager(DatabaseManag
return oid_list, user, desc, ext, bool(packed)
return None
- def getOIDList(self, min_oid, length, num_partitions,
- partition_list):
- q = self.query
- r = q("""SELECT DISTINCT oid FROM obj WHERE
- MOD(oid, %(num_partitions)d) in (%(partitions)s)
- AND oid >= %(min_oid)d
- ORDER BY oid ASC LIMIT %(length)d""" % {
- 'num_partitions': num_partitions,
- 'partitions': ','.join([str(p) for p in partition_list]),
- 'min_oid': util.u64(min_oid),
- 'length': length,
- })
- return [util.p64(t[0]) for t in r]
-
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
@@ -646,18 +649,32 @@ class MySQLDatabaseManager(DatabaseManag
return result
return None
- def getObjectHistoryFrom(self, oid, min_serial, length):
+ def getObjectHistoryFrom(self, min_oid, min_serial, length, num_partitions,
+ partition):
q = self.query
- oid = util.u64(oid)
+ u64 = util.u64
p64 = util.p64
- r = q("""SELECT serial FROM obj
- WHERE oid = %(oid)d AND serial >= %(min_serial)d
- ORDER BY serial ASC LIMIT %(length)d""" % {
- 'oid': oid,
- 'min_serial': util.u64(min_serial),
+ min_oid = u64(min_oid)
+ min_serial = u64(min_serial)
+ r = q('SELECT oid, serial FROM obj '
+ 'WHERE ((oid = %(min_oid)d AND serial >= %(min_serial)d) OR '
+ 'oid > %(min_oid)d) AND '
+ 'MOD(oid, %(num_partitions)d) = %(partition)s '
+ 'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
+ 'min_oid': min_oid,
+ 'min_serial': min_serial,
'length': length,
+ 'num_partitions': num_partitions,
+ 'partition': partition,
})
- return [p64(t[0]) for t in r]
+ result = {}
+ for oid, serial in r:
+ try:
+ serial_list = result[oid]
+ except KeyError:
+ serial_list = result[oid] = []
+ serial_list.append(p64(serial))
+ return dict((p64(x), y) for x, y in result.iteritems())
def getTIDList(self, offset, length, num_partitions, partition_list):
q = self.query
@@ -669,32 +686,19 @@ class MySQLDatabaseManager(DatabaseManag
return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, length, num_partitions,
- partition_list):
+ partition):
q = self.query
r = q("""SELECT tid FROM trans WHERE
- MOD(tid, %(num_partitions)d) in (%(partitions)s)
+ MOD(tid, %(num_partitions)d) = %(partition)d
AND tid >= %(min_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
- 'partitions': ','.join([str(p) for p in partition_list]),
+ 'partition': partition,
'min_tid': util.u64(min_tid),
'length': length,
})
return [util.p64(t[0]) for t in r]
- def getTIDListPresent(self, tid_list):
- q = self.query
- r = q("""SELECT tid FROM trans WHERE tid in (%s)""" \
- % ','.join([str(util.u64(tid)) for tid in tid_list]))
- return [util.p64(t[0]) for t in r]
-
- def getSerialListPresent(self, oid, serial_list):
- q = self.query
- oid = util.u64(oid)
- r = q("""SELECT serial FROM obj WHERE oid = %d AND serial in (%s)""" \
- % (oid, ','.join([str(util.u64(serial)) for serial in serial_list])))
- return [util.p64(t[0]) for t in r]
-
def _updatePackFuture(self, oid, orig_serial, max_serial,
updateObjectDataForPack):
q = self.query
@@ -783,4 +787,54 @@ class MySQLDatabaseManager(DatabaseManag
self.rollback()
raise
self.commit()
+
+ def checkTIDRange(self, min_tid, length, num_partitions, partition):
+ # XXX: XOR is a lame checksum
+ count, tid_checksum, max_tid = self.query('SELECT COUNT(*), '
+ 'BIT_XOR(tid), MAX(tid) FROM ('
+ 'SELECT tid FROM trans '
+ 'WHERE MOD(tid, %(num_partitions)d) = %(partition)s '
+ 'AND tid >= %(min_tid)d '
+ 'ORDER BY tid ASC LIMIT %(length)d'
+ ') AS foo' % {
+ 'num_partitions': num_partitions,
+ 'partition': partition,
+ 'min_tid': util.u64(min_tid),
+ 'length': length,
+ })[0]
+ if count == 0:
+ tid_checksum = 0
+ max_tid = ZERO_TID
+ else:
+ max_tid = util.p64(max_tid)
+ return count, tid_checksum, max_tid
+
+ def checkSerialRange(self, min_oid, min_serial, length, num_partitions,
+ partition):
+ # XXX: XOR is a lame checksum
+ u64 = util.u64
+ p64 = util.p64
+ r = self.query('SELECT oid, serial FROM obj WHERE '
+ '(oid > %(min_oid)d OR '
+ '(oid = %(min_oid)d AND serial >= %(min_serial)d)) '
+ 'AND MOD(oid, %(num_partitions)d) = %(partition)s '
+ 'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
+ 'min_oid': u64(min_oid),
+ 'min_serial': u64(min_serial),
+ 'length': length,
+ 'num_partitions': num_partitions,
+ 'partition': partition,
+ })
+ count = len(r)
+ oid_checksum = serial_checksum = 0
+ if count == 0:
+ max_oid = ZERO_OID
+ max_serial = ZERO_TID
+ else:
+ for max_oid, max_serial in r:
+ oid_checksum ^= max_oid
+ serial_checksum ^= max_serial
+ max_oid = p64(max_oid)
+ max_serial = p64(max_serial)
+ return count, oid_checksum, max_oid, serial_checksum, max_serial
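checkTIDRange and checkSerialRange reduce a chunk to a small comparable
summary. A pure-Python sketch of the same (count, checksum, max) triple over
plain integers, assuming it mirrors what the COUNT(*)/BIT_XOR()/MAX() query
above computes; XOR is order-independent, so two nodes holding the same rows
get the same triple:

def check_tid_range(tids, min_tid, length, num_partitions, partition):
    chunk = sorted(t for t in tids
                   if t >= min_tid and t % num_partitions == partition)[:length]
    checksum = 0
    for tid in chunk:
        checksum ^= tid
    return len(chunk), checksum, (chunk[-1] if chunk else 0)

assert check_tid_range([1, 5, 9], 0, 10, 4, 1) == \
       check_tid_range([9, 1, 5], 0, 10, 4, 1)
assert check_tid_range([1, 5], 0, 10, 4, 1) != \
       check_tid_range([1, 5, 9], 0, 10, 4, 1)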
Modified: trunk/neo/storage/handlers/replication.py
==============================================================================
--- trunk/neo/storage/handlers/replication.py [iso-8859-1] (original)
+++ trunk/neo/storage/handlers/replication.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -22,6 +22,48 @@ from neo.handler import EventHandler
from neo.protocol import Packets, ZERO_TID, ZERO_OID
from neo import util
+# TODO: benchmark how different values behave
+RANGE_LENGTH = 4000
+MIN_RANGE_LENGTH = 1000
+
+"""
+Replication algorithm
+
+Purpose: replicate the content of a reference node into a replicating node,
+bringing it up-to-date.
+This happens both when a new storage is added to an existing cluster and
+when a node that was separated from the cluster rejoins it.
+
+Replication happens per partition. The reference node can change between
+partitions.
+
+2 parts, done sequentially:
+- Transaction (metadata) replication
+- Object (data) replication
+
+Both parts follow the same mechanism:
+- On both sides (replicating and reference), compute a checksum of a chunk
+  (RANGE_LENGTH entries). If there is a mismatch, the chunk size is reduced
+  and the scan restarts from the same row, until it reaches a minimal length
+  (MIN_RANGE_LENGTH); then all rows in that chunk are replicated. If the
+  chunks match, move on to the next chunk.
+- Replicating a chunk starts with asking for the list of all entries (only
+  their identifiers), skipping those that both sides have, deleting those
+  that the replicating node has but the reference doesn't, and asking
+  individually for every entry missing on the replicating side.
+"""
+
+# TODO: Make object replication get ordered by serial first and oid second, so
+# changes are in a big segment at the end, rather than in many segments (one
+# per object).
+
+# TODO: To improve performance when a pack happened, the following algorithm
+# should be used:
+# - If reference node packed, find non-existent oids in reference node (their
+# creation was undone, and pack pruned them), and delete them.
+# - Run current algorithm, starting at our last pack TID.
+# - Pack partition at reference's TID.
+
def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw):
if self.app.replicator.current_connection is conn:
@@ -51,28 +93,26 @@ class ReplicationHandler(EventHandler):
uuid, num_partitions, num_replicas, your_uuid):
# set the UUID on the connection
conn.setUUID(uuid)
+ self.startReplication(conn)
+
+ def startReplication(self, conn):
+ conn.ask(self._doAskCheckTIDRange(ZERO_TID), timeout=300)
@checkConnectionIsReplicatorConnection
def answerTIDsFrom(self, conn, tid_list):
app = self.app
- if tid_list:
- # If I have pending TIDs, check which TIDs I don't have, and
- # request the data.
- present_tid_list = app.dm.getTIDListPresent(tid_list)
- tid_set = set(tid_list) - set(present_tid_list)
- for tid in tid_set:
- conn.ask(Packets.AskTransactionInformation(tid), timeout=300)
-
- # And, ask more TIDs.
- p = Packets.AskTIDsFrom(add64(tid_list[-1], 1), 1000,
- app.replicator.current_partition.getRID())
- conn.ask(p, timeout=300)
- else:
- # If no more TID, a replication of transactions is finished.
- # So start to replicate objects now.
- p = Packets.AskOIDs(ZERO_OID, 1000,
- app.replicator.current_partition.getRID())
- conn.ask(p, timeout=300)
+ # If I have pending TIDs, check which TIDs I don't have, and
+ # request the data.
+ tid_set = frozenset(tid_list)
+ my_tid_set = frozenset(app.replicator.getTIDsFromResult())
+ extra_tid_set = my_tid_set - tid_set
+ if extra_tid_set:
+ deleteTransaction = app.dm.deleteTransaction
+ for tid in extra_tid_set:
+ deleteTransaction(tid)
+ missing_tid_set = tid_set - my_tid_set
+ for tid in missing_tid_set:
+ conn.ask(Packets.AskTransactionInformation(tid), timeout=300)
@checkConnectionIsReplicatorConnection
def answerTransactionInformation(self, conn, tid,
@@ -83,46 +123,23 @@ class ReplicationHandler(EventHandler):
False)
@checkConnectionIsReplicatorConnection
- def answerOIDs(self, conn, oid_list):
+ def answerObjectHistoryFrom(self, conn, object_dict):
app = self.app
- if oid_list:
- app.replicator.next_oid = add64(oid_list[-1], 1)
- # Pick one up, and ask the history.
- oid = oid_list.pop()
- conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
- timeout=300)
- app.replicator.oid_list = oid_list
- else:
- # Nothing remains, so the replication for this partition is
- # finished.
- app.replicator.replication_done = True
-
- @checkConnectionIsReplicatorConnection
- def answerObjectHistoryFrom(self, conn, oid, serial_list):
- app = self.app
- if serial_list:
+ my_object_dict = app.replicator.getObjectHistoryFromResult()
+ deleteObject = app.dm.deleteObject
+ for oid, serial_list in object_dict.iteritems():
# Check if I have objects, request those which I don't have.
- present_serial_list = app.dm.getSerialListPresent(oid, serial_list)
- serial_set = set(serial_list) - set(present_serial_list)
- for serial in serial_set:
- conn.ask(Packets.AskObject(oid, serial, None), timeout=300)
-
- # And, ask more serials.
- conn.ask(Packets.AskObjectHistoryFrom(oid,
- add64(serial_list[-1], 1), 1000), timeout=300)
- else:
- # This OID is finished. So advance to next.
- oid_list = app.replicator.oid_list
- if oid_list:
- # If I have more pending OIDs, pick one up.
- oid = oid_list.pop()
- conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
- timeout=300)
+ if oid in my_object_dict:
+ my_serial_set = frozenset(my_object_dict[oid])
+ serial_set = frozenset(serial_list)
+ extra_serial_set = my_serial_set - serial_set
+ for serial in extra_serial_set:
+ deleteObject(oid, serial)
+ missing_serial_set = serial_set - my_serial_set
else:
- # Otherwise, acquire more OIDs.
- p = Packets.AskOIDs(app.replicator.next_oid, 1000,
- app.replicator.current_partition.getRID())
- conn.ask(p, timeout=300)
+ missing_serial_set = serial_list
+ for serial in missing_serial_set:
+ conn.ask(Packets.AskObject(oid, serial, None), timeout=300)
@checkConnectionIsReplicatorConnection
def answerObject(self, conn, oid, serial_start,
@@ -134,3 +151,97 @@ class ReplicationHandler(EventHandler):
del obj
del data
+ def _doAskCheckSerialRange(self, min_oid, min_tid, length=RANGE_LENGTH):
+ replicator = self.app.replicator
+ partition = replicator.current_partition.getRID()
+ replicator.checkSerialRange(min_oid, min_tid, length, partition)
+ return Packets.AskCheckSerialRange(min_oid, min_tid, length, partition)
+
+ def _doAskCheckTIDRange(self, min_tid, length=RANGE_LENGTH):
+ replicator = self.app.replicator
+ partition = replicator.current_partition.getRID()
+ replicator.checkTIDRange(min_tid, length, partition)
+ return Packets.AskCheckTIDRange(min_tid, length, partition)
+
+ def _doAskTIDsFrom(self, min_tid, length):
+ replicator = self.app.replicator
+ partition = replicator.current_partition.getRID()
+ replicator.getTIDsFrom(min_tid, length, partition)
+ return Packets.AskTIDsFrom(min_tid, length, partition)
+
+ def _doAskObjectHistoryFrom(self, min_oid, min_serial, length):
+ replicator = self.app.replicator
+ partition = replicator.current_partition.getRID()
+ replicator.getObjectHistoryFrom(min_oid, min_serial, length, partition)
+ return Packets.AskObjectHistoryFrom(min_oid, min_serial, length,
+ partition)
+
+ @checkConnectionIsReplicatorConnection
+ def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
+ max_tid):
+ app = self.app
+ replicator = app.replicator
+ our = replicator.getTIDCheckResult(min_tid, length)
+ his = (count, tid_checksum, max_tid)
+ our_count = our[0]
+ our_max_tid = our[2]
+ p = None
+ if our != his:
+ # Something is different...
+ if length <= MIN_RANGE_LENGTH:
+ # We are already at minimum chunk length, replicate.
+ conn.ask(self._doAskTIDsFrom(min_tid, count))
+ else:
+ # Check a smaller chunk.
+ # Note: this could be made into a real binary search, but is
+ # it really worth the work ?
+ # Note: +1, so we can detect we reached the end when answer
+ # comes back.
+ p = self._doAskCheckTIDRange(min_tid, min(length / 2,
+ count + 1))
+ if p is None:
+ if count == length:
+ # Go on with next chunk
+ p = self._doAskCheckTIDRange(add64(max_tid, 1))
+ else:
+ # If no more TID, a replication of transactions is finished.
+ # So start to replicate objects now.
+ p = self._doAskCheckSerialRange(ZERO_OID, ZERO_TID)
+ conn.ask(p)
+
+ @checkConnectionIsReplicatorConnection
+ def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
+ oid_checksum, max_oid, serial_checksum, max_serial):
+ app = self.app
+ replicator = app.replicator
+ our = replicator.getSerialCheckResult(min_oid, min_serial, length)
+ his = (count, oid_checksum, max_oid, serial_checksum, max_serial)
+ our_count = our[0]
+ our_max_oid = our[2]
+ our_max_serial = our[4]
+ p = None
+ if our != his:
+ # Something is different...
+ if length <= MIN_RANGE_LENGTH:
+ # We are already at minimum chunk length, replicate.
+ conn.ask(self._doAskObjectHistoryFrom(min_oid, min_serial,
+ count))
+ else:
+ # Check a smaller chunk.
+ # Note: this could be made into a real binary search, but is
+ # it really worth the work ?
+ # Note: +1, so we can detect we reached the end when answer
+ # comes back.
+ p = self._doAskCheckSerialRange(min_oid, min_serial,
+ min(length / 2, count + 1))
+ if p is None:
+ if count == length:
+ # Go on with next chunk
+ p = self._doAskCheckSerialRange(max_oid, add64(max_serial, 1))
+ else:
+ # Nothing remains, so the replication for this partition is
+ # finished.
+ replicator.replication_done = True
+ if p is not None:
+ conn.ask(p)
+
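The branching in answerCheckTIDRange above can be read as a pure function from
(our summary, peer summary) to the next requests; a hedged sketch with integer
ids and illustrative labels instead of the real packet classes:

def next_tid_actions(our, his, min_tid, length,
                     range_length=4000, min_range=1000):
    count, _, max_tid = his
    actions = []
    if our != his:
        if length > min_range:
            # Re-check a smaller window; the +1 lets the next answer
            # reveal whether the peer has run out of rows.
            return [('AskCheckTIDRange', min_tid,
                     min(length // 2, count + 1))]
        # Minimal window and still different: fetch the peer's TID list.
        actions.append(('AskTIDsFrom', min_tid, count))
    if count == length:
        # The chunk was full: go on with the next one.
        actions.append(('AskCheckTIDRange', max_tid + 1, range_length))
    else:
        # Fewer rows than asked: transactions done, switch to objects.
        actions.append(('AskCheckSerialRange', 0, 0, range_length))
    return actions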
Modified: trunk/neo/storage/handlers/storage.py
==============================================================================
--- trunk/neo/storage/handlers/storage.py [iso-8859-1] (original)
+++ trunk/neo/storage/handlers/storage.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -30,34 +30,32 @@ class StorageOperationHandler(BaseClient
tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
- def askOIDs(self, conn, min_oid, length, partition):
- # This method is complicated, because I must return OIDs only
- # about usable partitions assigned to me.
- app = self.app
- if partition == protocol.INVALID_PARTITION:
- partition_list = app.pt.getAssignedPartitionList(app.uuid)
- else:
- partition_list = [partition]
-
- oid_list = app.dm.getOIDList(min_oid, length,
- app.pt.getPartitions(), partition_list)
- conn.answer(Packets.AnswerOIDs(oid_list))
-
def askTIDsFrom(self, conn, min_tid, length, partition):
- # This method is complicated, because I must return TIDs only
- # about usable partitions assigned to me.
app = self.app
- if partition == protocol.INVALID_PARTITION:
- partition_list = app.pt.getAssignedPartitionList(app.uuid)
- else:
- partition_list = [partition]
-
tid_list = app.dm.getReplicationTIDList(min_tid, length,
- app.pt.getPartitions(), partition_list)
+ app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerTIDsFrom(tid_list))
- def askObjectHistoryFrom(self, conn, oid, min_serial, length):
+ def askObjectHistoryFrom(self, conn, min_oid, min_serial, length,
+ partition):
+ app = self.app
+ object_dict = app.dm.getObjectHistoryFrom(min_oid, min_serial, length,
+ app.pt.getPartitions(), partition)
+ conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
+
+ def askCheckTIDRange(self, conn, min_tid, length, partition):
+ app = self.app
+ count, tid_checksum, max_tid = app.dm.checkTIDRange(min_tid, length,
+ app.pt.getPartitions(), partition)
+ conn.answer(Packets.AnswerCheckTIDRange(min_tid, length, count,
+ tid_checksum, max_tid))
+
+ def askCheckSerialRange(self, conn, min_oid, min_serial, length,
+ partition):
app = self.app
- history_list = app.dm.getObjectHistoryFrom(oid, min_serial, length)
- conn.answer(Packets.AnswerObjectHistoryFrom(oid, history_list))
+ count, oid_checksum, max_oid, serial_checksum, max_serial = \
+ app.dm.checkSerialRange(min_oid, min_serial, length,
+ app.pt.getPartitions(), partition)
+ conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
+ count, oid_checksum, max_oid, serial_checksum, max_serial))
Modified: trunk/neo/storage/replicator.py
==============================================================================
--- trunk/neo/storage/replicator.py [iso-8859-1] (original)
+++ trunk/neo/storage/replicator.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -46,6 +46,46 @@ class Partition(object):
return tid is not None and (
min_pending_tid is None or tid < min_pending_tid)
+class Task(object):
+ """
+ A Task is a callable to execute at another time, with given parameters.
+ Execution result is kept and can be retrieved later.
+ """
+
+ _func = None
+ _args = None
+ _kw = None
+ _result = None
+ _processed = False
+
+ def __init__(self, func, args=(), kw=None):
+ self._func = func
+ self._args = args
+ if kw is None:
+ kw = {}
+ self._kw = kw
+
+ def process(self):
+ if self._processed:
+ raise ValueError, 'You cannot process a single Task twice'
+ self._processed = True
+ self._result = self._func(*self._args, **self._kw)
+
+ def getResult(self):
+ # Should we instead execute immediately rather than raising ?
+ if not self._processed:
+ raise ValueError, 'You cannot get a result until task is executed'
+ return self._result
+
+ def __repr__(self):
+ fmt = '<%s at %x %r(*%r, **%r)%%s>' % (self.__class__.__name__,
+ id(self), self._func, self._args, self._kw)
+ if self._processed:
+ extra = ' => %r' % (self._result, )
+ else:
+ extra = ''
+ return fmt % (extra, )
+
class Replicator(object):
"""This class handles replications of objects and transactions.
@@ -98,21 +138,23 @@ class Replicator(object):
# didn't answer yet.
# unfinished_tid_list
# The list of unfinished TIDs known by master node.
- # oid_list
- # List of OIDs to replicate. Doesn't contains currently-replicated
- # object.
- # XXX: not defined here
- # XXX: accessed (r/w) directly by ReplicationHandler
- # next_oid
- # Next OID to ask when oid_list is empty.
- # XXX: not defined here
- # XXX: accessed (r/w) directly by ReplicationHandler
# replication_done
# False if we know there is something to replicate.
# True when current_partition is replicated, or we don't know yet if
# there is something to replicate
# XXX: accessed (w) directly by ReplicationHandler
+ new_partition_dict = None
+ critical_tid_dict = None
+ partition_dict = None
+ task_list = None
+ task_dict = None
+ current_partition = None
+ current_connection = None
+ waiting_for_unfinished_tids = None
+ unfinished_tid_list = None
+ replication_done = None
+
def __init__(self, app):
self.app = app
@@ -129,6 +171,8 @@ class Replicator(object):
def reset(self):
"""Reset attributes to restart replicating."""
+ self.task_list = []
+ self.task_dict = {}
self.current_partition = None
self.current_connection = None
self.waiting_for_unfinished_tids = False
@@ -213,15 +257,12 @@ class Replicator(object):
p = Packets.RequestIdentification(NodeTypes.STORAGE,
app.uuid, app.server, app.name)
self.current_connection.ask(p)
-
- p = Packets.AskTIDsFrom(ZERO_TID, 1000,
- self.current_partition.getRID())
- self.current_connection.ask(p, timeout=300)
-
+ else:
+ self.current_connection.getHandler().startReplication(
+ self.current_connection)
self.replication_done = False
def _finishReplication(self):
- app = self.app
# TODO: remove try..except: pass
try:
self.partition_dict.pop(self.current_partition.getRID())
@@ -243,7 +284,11 @@ class Replicator(object):
self._askCriticalTID()
if self.current_partition is not None:
- if self.replication_done:
+ # Don't end replication until we have received all expected
+ # answers, as we might have asked object data just before the last
+ # AnswerCheckSerialRange.
+ if self.replication_done and \
+ not self.current_connection.isPending():
# finish a replication
logging.info('replication is done for %s' %
(self.current_partition.getRID(), ))
@@ -289,3 +334,57 @@ class Replicator(object):
and not self.new_partition_dict.has_key(rid):
self.new_partition_dict[rid] = Partition(rid)
+ def _addTask(self, key, func, args=(), kw=None):
+ task = Task(func, args, kw)
+ task_dict = self.task_dict
+ if key in task_dict:
+ raise ValueError, 'Task with key %r already exists (%r), cannot ' \
+ 'add %r' % (key, task_dict[key], task)
+ task_dict[key] = task
+ self.task_list.append(task)
+
+ def processDelayedTasks(self):
+ task_list = self.task_list
+ if task_list:
+ for task in task_list:
+ task.process()
+ self.task_list = []
+
+ def checkTIDRange(self, min_tid, length, partition):
+ app = self.app
+ self._addTask(('TID', min_tid, length), app.dm.checkTIDRange,
+ (min_tid, length, app.pt.getPartitions(), partition))
+
+ def checkSerialRange(self, min_oid, min_serial, length, partition):
+ app = self.app
+ self._addTask(('Serial', min_oid, min_serial, length),
+ app.dm.checkSerialRange, (min_oid, min_serial, length,
+ app.pt.getPartitions(), partition))
+
+ def getTIDsFrom(self, min_tid, length, partition):
+ app = self.app
+ self._addTask('TIDsFrom',
+ app.dm.getReplicationTIDList, (min_tid, length,
+ app.pt.getPartitions(), partition))
+
+ def getObjectHistoryFrom(self, min_oid, min_serial, length, partition):
+ app = self.app
+ self._addTask('ObjectHistoryFrom',
+ app.dm.getObjectHistoryFrom, (min_oid, min_serial, length,
+ app.pt.getPartitions(), partition))
+
+ def _getCheckResult(self, key):
+ return self.task_dict.pop(key).getResult()
+
+ def getTIDCheckResult(self, min_tid, length):
+ return self._getCheckResult(('TID', min_tid, length))
+
+ def getSerialCheckResult(self, min_oid, min_serial, length):
+ return self._getCheckResult(('Serial', min_oid, min_serial, length))
+
+ def getTIDsFromResult(self):
+ return self._getCheckResult('TIDsFrom')
+
+ def getObjectHistoryFromResult(self):
+ return self._getCheckResult('ObjectHistoryFrom')
+
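A short usage sketch of the Task helper and accessors added above (toy
callable and values; in the commit the keys are tuples such as
('TID', min_tid, length) and the callables are DatabaseManager methods):

from neo.storage.replicator import Task

task = Task(lambda a, b: a + b, (1, 2))
task.process()                  # done later by processDelayedTasks()
assert task.getResult() == 3    # read back when the peer's answer arrives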
Added: trunk/neo/tests/storage/testReplicationHandler.py
==============================================================================
--- trunk/neo/tests/storage/testReplicationHandler.py (added)
+++ trunk/neo/tests/storage/testReplicationHandler.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -0,0 +1,512 @@
+#
+# Copyright (C) 2010 Nexedi SA
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; either version 2
+# of the License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+
+import unittest
+from mock import Mock
+from neo.tests import NeoTestBase
+from neo.protocol import Packets, ZERO_OID, ZERO_TID
+from neo.storage.handlers.replication import ReplicationHandler, add64
+from neo.storage.handlers.replication import RANGE_LENGTH, MIN_RANGE_LENGTH
+
+class FakeDict(object):
+ def __init__(self, items):
+ self._items = items
+ self._dict = dict(items)
+ assert len(self._dict) == len(items), self._dict
+
+ def iteritems(self):
+ for item in self._items:
+ yield item
+
+ def iterkeys(self):
+ for key, value in self.iteritems():
+ yield key
+
+ def itervalues(self):
+ for key, value in self.iteritems():
+ yield value
+
+ def items(self):
+ return self._items[:]
+
+ def keys(self):
+ return [x for x, y in self._items]
+
+ def values(self):
+ return [y for x, y in self._items]
+
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ def __getattr__(self, key):
+ return getattr(self._dict, key)
+
+ def __len__(self):
+ return len(self._dict)
+
+class StorageReplicationHandlerTests(NeoTestBase):
+
+ def setup(self):
+ pass
+
+ def teardown(self):
+ pass
+
+ def getApp(self, conn=None, tid_check_result=(0, 0, ZERO_TID),
+ serial_check_result=(0, 0, ZERO_OID, 0, ZERO_TID),
+ tid_result=(),
+ history_result=None,
+ rid=0, critical_tid=ZERO_TID):
+ if history_result is None:
+ history_result = {}
+ replicator = Mock({
+ '__repr__': 'Fake replicator',
+ 'reset': None,
+ 'checkSerialRange': None,
+ 'checkTIDRange': None,
+ 'getTIDCheckResult': tid_check_result,
+ 'getSerialCheckResult': serial_check_result,
+ 'getTIDsFromResult': tid_result,
+ 'getObjectHistoryFromResult': history_result,
+ 'checkSerialRange': None,
+ 'checkTIDRange': None,
+ 'getTIDsFrom': None,
+ 'getObjectHistoryFrom': None,
+ })
+ replicator.current_partition = Mock({
+ 'getRID': rid,
+ 'getCriticalTID': critical_tid,
+ })
+ replicator.current_connection = conn
+ real_replicator = replicator
+ class FakeApp(object):
+ replicator = real_replicator
+ dm = Mock({
+ 'storeTransaction': None,
+ })
+ return FakeApp
+
+ def _checkReplicationStarted(self, conn, rid, replicator):
+ min_tid, length, partition = self.checkAskPacket(conn,
+ Packets.AskCheckTIDRange, decode=True)
+ self.assertEqual(min_tid, ZERO_TID)
+ self.assertEqual(length, RANGE_LENGTH)
+ self.assertEqual(partition, rid)
+ calls = replicator.mockGetNamedCalls('checkTIDRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(min_tid, length, partition)
+
+ def _checkPacketTIDList(self, conn, tid_list):
+ packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
+ self.assertEqual(len(packet_list), len(tid_list))
+ for packet in packet_list:
+ self.assertEqual(packet.getType(),
+ Packets.AskTransactionInformation)
+ ptid = packet.decode()[0]
+ for tid in tid_list:
+ if ptid == tid:
+ tid_list.remove(tid)
+ break
+ else:
+ raise AssertionError, '%s not found in %r' % (dump(ptid),
+ [dump(x) for x in tid_list])
+
+ def _checkPacketSerialList(self, conn, object_list):
+ packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
+ self.assertEqual(len(packet_list), len(object_list))
+ for packet, (oid, serial) in zip(packet_list, object_list):
+ self.assertEqual(packet.getType(),
+ Packets.AskObject)
+ self.assertEqual(packet.decode(), (oid, serial, None))
+
+ def test_connectionLost(self):
+ app = self.getApp()
+ ReplicationHandler(app).connectionLost(None, None)
+ self.assertEqual(len(app.replicator.mockGetNamedCalls('reset')), 1)
+
+ def test_connectionFailed(self):
+ app = self.getApp()
+ ReplicationHandler(app).connectionFailed(None)
+ self.assertEqual(len(app.replicator.mockGetNamedCalls('reset')), 1)
+
+ def test_acceptIdentification(self):
+ rid = 24
+ app = self.getApp(rid=rid)
+ conn = self.getFakeConnection()
+ replication = ReplicationHandler(app)
+ replication.acceptIdentification(conn, None, None, None,
+ None, None)
+ self._checkReplicationStarted(conn, rid, app.replicator)
+
+ def test_startReplication(self):
+ rid = 24
+ app = self.getApp(rid=rid)
+ conn = self.getFakeConnection()
+ ReplicationHandler(app).startReplication(conn)
+ self._checkReplicationStarted(conn, rid, app.replicator)
+
+ def test_answerTIDsFrom(self):
+ conn = self.getFakeConnection()
+ tid_list = [self.getNextTID(), self.getNextTID()]
+ app = self.getApp(conn=conn, tid_result=[])
+ # With no known TID
+ ReplicationHandler(app).answerTIDsFrom(conn, tid_list)
+ self._checkPacketTIDList(conn, tid_list[:])
+ # With first TID known
+ conn = self.getFakeConnection()
+ known_tid_list = [tid_list[0], ]
+ unknown_tid_list = [tid_list[1], ]
+ app = self.getApp(conn=conn, tid_result=known_tid_list)
+ ReplicationHandler(app).answerTIDsFrom(conn, tid_list)
+ self._checkPacketTIDList(conn, unknown_tid_list)
+
+ def test_answerTransactionInformation(self):
+ conn = self.getFakeConnection()
+ app = self.getApp(conn=conn)
+ tid = self.getNextTID()
+ user = 'foo'
+ desc = 'bar'
+ ext = 'baz'
+ packed = True
+ oid_list = [self.getOID(1), self.getOID(2)]
+ ReplicationHandler(app).answerTransactionInformation(conn, tid, user,
+ desc, ext, packed, oid_list)
+ calls = app.dm.mockGetNamedCalls('storeTransaction')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(tid, (), (oid_list, user, desc, ext, packed), False)
+
+ def test_answerObjectHistoryFrom(self):
+ conn = self.getFakeConnection()
+ oid_1 = self.getOID(1)
+ oid_2 = self.getOID(2)
+ oid_3 = self.getOID(3)
+ oid_dict = FakeDict((
+ (oid_1, [self.getNextTID(), self.getNextTID()]),
+ (oid_2, [self.getNextTID()]),
+ (oid_3, [self.getNextTID()]),
+ ))
+ flat_oid_list = []
+ for oid, serial_list in oid_dict.iteritems():
+ for serial in serial_list:
+ flat_oid_list.append((oid, serial))
+ app = self.getApp(conn=conn, history_result={})
+ # With no known OID/Serial
+ ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
+ self._checkPacketSerialList(conn, flat_oid_list)
+ # With some known OID/Serials
+ conn = self.getFakeConnection()
+ app = self.getApp(conn=conn, history_result={
+ oid_1: [oid_dict[oid_1][0], ],
+ oid_3: [oid_dict[oid_3][0], ],
+ })
+ ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
+ self._checkPacketSerialList(conn, (
+ (oid_1, oid_dict[oid_1][1]),
+ (oid_2, oid_dict[oid_2][0]),
+ ))
+
+ def test_answerObject(self):
+ conn = self.getFakeConnection()
+ app = self.getApp(conn=conn)
+ oid = self.getOID(1)
+ serial_start = self.getNextTID()
+ serial_end = self.getNextTID()
+ compression = 1
+ checksum = 2
+ data = 'foo'
+ data_serial = None
+ ReplicationHandler(app).answerObject(conn, oid, serial_start,
+ serial_end, compression, checksum, data, data_serial)
+ calls = app.dm.mockGetNamedCalls('storeTransaction')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(serial_start, [(oid, compression, checksum, data,
+ data_serial)], None, False)
+
+ # CheckTIDRange
+ def test_answerCheckTIDRangeIdenticalChunkWithNext(self):
+ min_tid = self.getNextTID()
+ max_tid = self.getNextTID()
+ length = RANGE_LENGTH / 2
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
+ conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has the same data as we have: length, checksum and max_tid
+ # match.
+ handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
+ # Result: go on with next chunk
+ pmin_tid, plength, ppartition = self.checkAskPacket(conn,
+ Packets.AskCheckTIDRange, decode=True)
+ self.assertEqual(pmin_tid, add64(max_tid, 1))
+ self.assertEqual(plength, RANGE_LENGTH)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('checkTIDRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_tid, plength, ppartition)
+
+ def test_answerCheckTIDRangeIdenticalChunkWithoutNext(self):
+ min_tid = self.getNextTID()
+ max_tid = self.getNextTID()
+ length = RANGE_LENGTH / 2
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(tid_check_result=(length - 1, 0, max_tid), rid=rid,
+ conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has the same data as we have: length, checksum and max_tid
+ # match.
+ handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
+ max_tid)
+ # Result: go on with object range checks
+ pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
+ Packets.AskCheckSerialRange, decode=True)
+ self.assertEqual(pmin_oid, ZERO_OID)
+ self.assertEqual(pmin_serial, ZERO_TID)
+ self.assertEqual(plength, RANGE_LENGTH)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('checkSerialRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
+
+ def test_answerCheckTIDRangeDifferentBigChunk(self):
+ min_tid = self.getNextTID()
+ max_tid = self.getNextTID()
+ length = RANGE_LENGTH / 2
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
+ conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has different data
+ handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
+ # Result: ask again, length halved
+ pmin_tid, plength, ppartition = self.checkAskPacket(conn,
+ Packets.AskCheckTIDRange, decode=True)
+ self.assertEqual(pmin_tid, min_tid)
+ self.assertEqual(plength, length / 2)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('checkTIDRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_tid, plength, ppartition)
+
+ def test_answerCheckTIDRangeDifferentSmallChunkWithNext(self):
+ min_tid = self.getNextTID()
+ max_tid = self.getNextTID()
+ length = MIN_RANGE_LENGTH - 1
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
+ conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has different data
+ handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
+ # Result: ask tid list, and ask next chunk
+ calls = conn.mockGetNamedCalls('ask')
+ self.assertEqual(len(calls), 2)
+ tid_call, next_call = calls
+ tid_packet = tid_call.getParam(0)
+ next_packet = next_call.getParam(0)
+ self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
+ pmin_tid, plength, ppartition = tid_packet.decode()
+ self.assertEqual(pmin_tid, min_tid)
+ self.assertEqual(plength, length)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_tid, plength, ppartition)
+ self.assertEqual(next_packet.getType(), Packets.AskCheckTIDRange)
+ pmin_tid, plength, ppartition = next_packet.decode()
+ self.assertEqual(pmin_tid, add64(max_tid, 1))
+ self.assertEqual(plength, RANGE_LENGTH)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('checkTIDRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_tid, plength, ppartition)
+
+ def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self):
+ min_tid = self.getNextTID()
+ max_tid = self.getNextTID()
+ length = MIN_RANGE_LENGTH - 1
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
+ conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has different data, and less than length
+ handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
+ max_tid)
+ # Result: ask tid list, and start replicating object range
+ calls = conn.mockGetNamedCalls('ask')
+ self.assertEqual(len(calls), 2)
+ tid_call, next_call = calls
+ tid_packet = tid_call.getParam(0)
+ next_packet = next_call.getParam(0)
+ self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
+ pmin_tid, plength, ppartition = tid_packet.decode()
+ self.assertEqual(pmin_tid, min_tid)
+ self.assertEqual(plength, length - 1)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_tid, plength, ppartition)
+ self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange)
+ pmin_oid, pmin_serial, plength, ppartition = next_packet.decode()
+ self.assertEqual(pmin_oid, ZERO_OID)
+ self.assertEqual(pmin_serial, ZERO_TID)
+ self.assertEqual(plength, RANGE_LENGTH)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('checkSerialRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
+
+ # CheckSerialRange
+ def test_answerCheckSerialRangeIdenticalChunkWithNext(self):
+ min_oid = self.getOID(1)
+ max_oid = self.getOID(10)
+ min_serial = self.getNextTID()
+ max_serial = self.getNextTID()
+ length = RANGE_LENGTH / 2
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(serial_check_result=(length, 0, max_oid, 1,
+ max_serial), rid=rid, conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has the same data as we have
+ handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
+ length, 0, max_oid, 1, max_serial)
+ # Result: go on with next chunk
+ pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
+ Packets.AskCheckSerialRange, decode=True)
+ self.assertEqual(pmin_oid, max_oid)
+ self.assertEqual(pmin_serial, add64(max_serial, 1))
+ self.assertEqual(plength, RANGE_LENGTH)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('checkSerialRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
+
+ def test_answerCheckSerialRangeIdenticalChunkWithoutNext(self):
+ min_oid = self.getOID(1)
+ max_oid = self.getOID(10)
+ min_serial = self.getNextTID()
+ max_serial = self.getNextTID()
+ length = RANGE_LENGTH / 2
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(serial_check_result=(length - 1, 0, max_oid, 1,
+ max_serial), rid=rid, conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has the same data as we have
+ handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
+ length - 1, 0, max_oid, 1, max_serial)
+ # Result: mark replication as done
+ self.checkNoPacketSent(conn)
+ self.assertTrue(app.replicator.replication_done)
+
+ def test_answerCheckSerialRangeDifferentBigChunk(self):
+ min_oid = self.getOID(1)
+ max_oid = self.getOID(10)
+ min_serial = self.getNextTID()
+ max_serial = self.getNextTID()
+ length = RANGE_LENGTH / 2
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(serial_check_result=(length - 5, 0, max_oid, 1,
+ max_serial), rid=rid, conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has different data
+ handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
+ length, 0, max_oid, 1, max_serial)
+ # Result: ask again, length halved
+ pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
+ Packets.AskCheckSerialRange, decode=True)
+ self.assertEqual(pmin_oid, min_oid)
+ self.assertEqual(pmin_serial, min_serial)
+ self.assertEqual(plength, length / 2)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('checkSerialRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
+
+ def test_answerCheckSerialRangeDifferentSmallChunkWithNext(self):
+ min_oid = self.getOID(1)
+ max_oid = self.getOID(10)
+ min_serial = self.getNextTID()
+ max_serial = self.getNextTID()
+ length = MIN_RANGE_LENGTH - 1
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(serial_check_result=(length - 5, 0, max_oid, 1,
+ max_serial), rid=rid, conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has different data
+ handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
+ length, 0, max_oid, 1, max_serial)
+ # Result: ask serial list, and ask next chunk
+ calls = conn.mockGetNamedCalls('ask')
+ self.assertEqual(len(calls), 2)
+ serial_call, next_call = calls
+ serial_packet = serial_call.getParam(0)
+ next_packet = next_call.getParam(0)
+ self.assertEqual(serial_packet.getType(), Packets.AskObjectHistoryFrom)
+ pmin_oid, pmin_serial, plength, ppartition = serial_packet.decode()
+ self.assertEqual(pmin_oid, min_oid)
+ self.assertEqual(pmin_serial, min_serial)
+ self.assertEqual(plength, length)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
+ self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange)
+ pmin_oid, pmin_serial, plength, ppartition = next_packet.decode()
+ self.assertEqual(pmin_oid, max_oid)
+ self.assertEqual(pmin_serial, add64(max_serial, 1))
+ self.assertEqual(plength, RANGE_LENGTH)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('checkSerialRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
+
+ def test_answerCheckSerialRangeDifferentSmallChunkWithoutNext(self):
+ min_oid = self.getOID(1)
+ max_oid = self.getOID(10)
+ min_serial = self.getNextTID()
+ max_serial = self.getNextTID()
+ length = MIN_RANGE_LENGTH - 1
+ rid = 12
+ conn = self.getFakeConnection()
+ app = self.getApp(serial_check_result=(length - 5, 0, max_oid,
+ 1, max_serial), rid=rid, conn=conn)
+ handler = ReplicationHandler(app)
+ # Peer has different data, and less than length
+ handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
+ length - 1, 0, max_oid, 1, max_serial)
+ # Result: ask tid list, and mark replication as done
+ pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
+ Packets.AskObjectHistoryFrom, decode=True)
+ self.assertEqual(pmin_oid, min_oid)
+ self.assertEqual(pmin_serial, min_serial)
+ self.assertEqual(plength, length - 1)
+ self.assertEqual(ppartition, rid)
+ calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
+ self.assertTrue(app.replicator.replication_done)
+
+if __name__ == "__main__":
+ unittest.main()
Added: trunk/neo/tests/storage/testReplicator.py
==============================================================================
--- trunk/neo/tests/storage/testReplicator.py (added)
+++ trunk/neo/tests/storage/testReplicator.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -0,0 +1,262 @@
+#
+# Copyright (C) 2010 Nexedi SA
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; either version 2
+# of the License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+
+import unittest
+from mock import Mock, ReturnValues
+from neo.tests import NeoTestBase
+from neo.storage.replicator import Replicator, Partition, Task
+from neo.protocol import CellStates, NodeStates, Packets
+
+class StorageReplicatorTests(NeoTestBase):
+
+ def setup(self):
+ pass
+
+ def teardown(self):
+ pass
+
+ def test_populate(self):
+ my_uuid = self.getNewUUID()
+ other_uuid = self.getNewUUID()
+ app = Mock()
+ app.uuid = my_uuid
+ app.pt = Mock({
+ 'getPartitions': 2,
+ 'getRow': ReturnValues(
+ ((my_uuid, CellStates.OUT_OF_DATE),
+ (other_uuid, CellStates.UP_TO_DATE), ),
+ ((my_uuid, CellStates.UP_TO_DATE),
+ (other_uuid, CellStates.OUT_OF_DATE), ),
+ ),
+ })
+ replicator = Replicator(app)
+ assert replicator.new_partition_dict is None, \
+ replicator.new_partition_dict
+ assert replicator.critical_tid_dict is None, \
+ replicator.critical_tid_dict
+ assert replicator.partition_dict is None, replicator.partition_dict
+ replicator.populate()
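+ # Only partition 0, where this node is OUT_OF_DATE, needs to be replicated.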
+ self.assertEqual(len(replicator.new_partition_dict), 1)
+ partition = replicator.new_partition_dict[0]
+ self.assertEqual(partition.getRID(), 0)
+ self.assertEqual(partition.getCriticalTID(), None)
+ self.assertEqual(replicator.critical_tid_dict, {})
+ self.assertEqual(replicator.partition_dict, {})
+
+ def test_reset(self):
+ replicator = Replicator(None)
+ assert replicator.task_list is None, replicator.task_list
+ assert replicator.task_dict is None, replicator.task_dict
+ assert replicator.current_partition is None, \
+ replicator.current_partition
+ assert replicator.current_connection is None, \
+ replicator.current_connection
+ assert replicator.waiting_for_unfinished_tids is None, \
+ replicator.waiting_for_unfinished_tids
+ assert replicator.unfinished_tid_list is None, \
+ replicator.unfinished_tid_list
+ assert replicator.replication_done is None, replicator.replication_done
+ replicator.reset()
+ self.assertEqual(replicator.task_list, [])
+ self.assertEqual(replicator.task_dict, {})
+ self.assertEqual(replicator.current_partition, None)
+ self.assertEqual(replicator.current_connection, None)
+ self.assertEqual(replicator.waiting_for_unfinished_tids, False)
+ self.assertEqual(replicator.unfinished_tid_list, None)
+ self.assertEqual(replicator.replication_done, True)
+
+ def test_setCriticalTID(self):
+ replicator = Replicator(None)
+ master_uuid = self.getNewUUID()
+ partition_list = [Partition(0), Partition(5)]
+ replicator.critical_tid_dict = {master_uuid: partition_list}
+ critical_tid = self.getNextTID()
+ for partition in partition_list:
+ self.assertEqual(partition.getCriticalTID(), None)
+ replicator.setCriticalTID(master_uuid, critical_tid)
+ self.assertEqual(replicator.critical_tid_dict, {})
+ for partition in partition_list:
+ self.assertEqual(partition.getCriticalTID(), critical_tid)
+
+ def test_setUnfinishedTIDList(self):
+ replicator = Replicator(None)
+ replicator.waiting_for_unfinished_tids = True
+ assert replicator.unfinished_tid_list is None, \
+ replicator.unfinished_tid_list
+ tid_list = [self.getNextTID(), ]
+ replicator.setUnfinishedTIDList(tid_list)
+ self.assertEqual(replicator.unfinished_tid_list, tid_list)
+ self.assertFalse(replicator.waiting_for_unfinished_tids)
+
+ def test_act(self):
+ # Also tests "pending"
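+ # Drive the replicator through a full partition: ask last IDs and
+ # unfinished TIDs, wait until the critical TID is safe, then replicate.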
+ uuid = self.getNewUUID()
+ master_uuid = self.getNewUUID()
+ bad_unfinished_tid = self.getNextTID()
+ critical_tid = self.getNextTID()
+ unfinished_tid = self.getNextTID()
+ app = Mock()
+ app.em = Mock({
+ 'register': None,
+ })
+ def connectorGenerator():
+ return Mock()
+ app.connector_handler = connectorGenerator
+ app.uuid = uuid
+ node_addr = ('127.0.0.1', 1234)
+ node = Mock({
+ 'getAddress': node_addr,
+ })
+ running_cell = Mock({
+ 'getNodeState': NodeStates.RUNNING,
+ 'getNode': node,
+ })
+ unknown_cell = Mock({
+ 'getNodeState': NodeStates.UNKNOWN,
+ })
+ app.pt = Mock({
+ 'getPartitions': 1,
+ 'getRow': ReturnValues(
+ ((uuid, CellStates.OUT_OF_DATE), ),
+ ),
+ 'getCellList': [running_cell, unknown_cell],
+ })
+ node_conn_handler = Mock({
+ 'startReplication': None,
+ })
+ node_conn = Mock({
+ 'getAddress': node_addr,
+ 'getHandler': node_conn_handler,
+ })
+ replicator = Replicator(app)
+ replicator.populate()
+ def act():
+ app.master_conn = self.getFakeConnection(uuid=master_uuid)
+ self.assertTrue(replicator.pending())
+ replicator.act()
+ # ask last IDs to infer critical_tid and unfinished tids
+ act()
+ last_ids, unfinished_tids = [x.getParam(0) for x in \
+ app.master_conn.mockGetNamedCalls('ask')]
+ self.assertEqual(last_ids.getType(), Packets.AskLastIDs)
+ self.assertFalse(replicator.new_partition_dict)
+ self.assertEqual(unfinished_tids.getType(),
+ Packets.AskUnfinishedTransactions)
+ self.assertTrue(replicator.waiting_for_unfinished_tids)
+ # nothing happens until waiting_for_unfinished_tids becomes False
+ act()
+ self.checkNoPacketSent(app.master_conn)
+ self.assertTrue(replicator.waiting_for_unfinished_tids)
+ # Send answers (guaranteed to happen in this order)
+ replicator.setCriticalTID(master_uuid, critical_tid)
+ act()
+ self.checkNoPacketSent(app.master_conn)
+ self.assertTrue(replicator.waiting_for_unfinished_tids)
+ # first time, there is an unfinished tid before critical tid,
+ # replication cannot start, and unfinished TIDs are asked again
+ replicator.setUnfinishedTIDList([unfinished_tid, bad_unfinished_tid])
+ self.assertFalse(replicator.waiting_for_unfinished_tids)
+ # Note: detection that nothing can be replicated happens on first call
+ # and unfinished tids are asked again on second call. This is ok, but
+ # might change, so just call twice.
+ act()
+ act()
+ self.checkAskPacket(app.master_conn, Packets.AskUnfinishedTransactions)
+ self.assertTrue(replicator.waiting_for_unfinished_tids)
+ # this time, critical tid check should be satisfied
+ replicator.setUnfinishedTIDList([unfinished_tid, ])
+ replicator.current_connection = node_conn
+ act()
+ self.assertEqual(replicator.current_partition,
+ replicator.partition_dict[0])
+ self.assertEqual(len(node_conn_handler.mockGetNamedCalls(
+ 'startReplication')), 1)
+ self.assertFalse(replicator.replication_done)
+ # Other calls should do nothing
+ replicator.current_connection = Mock()
+ act()
+ self.checkNoPacketSent(app.master_conn)
+ self.checkNoPacketSent(replicator.current_connection)
+ # Mark replication over for this partition
+ replicator.replication_done = True
+ # Don't finish while there are pending answers
+ replicator.current_connection = Mock({
+ 'isPending': True,
+ })
+ act()
+ self.assertTrue(replicator.pending())
+ replicator.current_connection = Mock({
+ 'isPending': False,
+ })
+ act()
+ # unfinished tid list will not be asked again
+ self.assertTrue(replicator.unfinished_tid_list)
+ # also, replication is over
+ self.assertFalse(replicator.pending())
+
+ def test_removePartition(self):
+ replicator = Replicator(None)
+ replicator.partition_dict = {0: None, 2: None}
+ replicator.new_partition_dict = {1: None}
+ replicator.removePartition(0)
+ self.assertEqual(replicator.partition_dict, {2: None})
+ self.assertEqual(replicator.new_partition_dict, {1: None})
+ replicator.removePartition(1)
+ replicator.removePartition(2)
+ self.assertEqual(replicator.partition_dict, {})
+ self.assertEqual(replicator.new_partition_dict, {})
+ # Must not raise
+ replicator.removePartition(3)
+
+ def test_addPartition(self):
+ replicator = Replicator(None)
+ replicator.partition_dict = {0: None}
+ replicator.new_partition_dict = {1: None}
+ replicator.addPartition(0)
+ replicator.addPartition(1)
+ self.assertEqual(replicator.partition_dict, {0: None})
+ self.assertEqual(replicator.new_partition_dict, {1: None})
+ replicator.addPartition(2)
+ self.assertEqual(replicator.partition_dict, {0: None})
+ self.assertEqual(len(replicator.new_partition_dict), 2)
+ self.assertEqual(replicator.new_partition_dict[1], None)
+ partition = replicator.new_partition_dict[2]
+ self.assertEqual(partition.getRID(), 2)
+ self.assertEqual(partition.getCriticalTID(), None)
+
+ def test_processDelayedTasks(self):
+ replicator = Replicator(None)
+ replicator.reset()
+ marker = []
+ def someCallable(foo, bar=None):
+ return (foo, bar)
+ replicator._addTask(1, someCallable, args=('foo', ))
+ self.assertRaises(ValueError, replicator._addTask, 1, None)
+ replicator._addTask(2, someCallable, args=('foo', ), kw={'bar': 'bar'})
+ replicator.processDelayedTasks()
+ self.assertEqual(replicator._getCheckResult(1), ('foo', None))
+ self.assertEqual(replicator._getCheckResult(2), ('foo', 'bar'))
+ # Also test Task
+ task = Task(someCallable, args=('foo', ))
+ self.assertRaises(ValueError, task.getResult)
+ task.process()
+ self.assertRaises(ValueError, task.process)
+ self.assertEqual(task.getResult(), ('foo', None))
+
+if __name__ == "__main__":
+ unittest.main()
+
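The Task assertions in test_processDelayedTasks imply roughly the behaviour below. This is an illustrative sketch inferred from the test only (attribute names are made up); the actual class is defined in trunk/neo/storage/replicator.py.

# Minimal sketch of the Task behaviour exercised above; inferred from the
# assertions, not copied from the real implementation.
class Task(object):
    def __init__(self, func, args=(), kw=None):
        self._func = func
        self._args = args
        self._kw = kw or {}
        self._result = None
        self._processed = False

    def process(self):
        # A task may only be processed once.
        if self._processed:
            raise ValueError('already processed')
        self._processed = True
        self._result = self._func(*self._args, **self._kw)

    def getResult(self):
        # The result only exists after process() ran.
        if not self._processed:
            raise ValueError('not processed yet')
        return self._result
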
Modified: trunk/neo/tests/storage/testStorageHandler.py
==============================================================================
--- trunk/neo/tests/storage/testStorageHandler.py [iso-8859-1] (original)
+++ trunk/neo/tests/storage/testStorageHandler.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -21,7 +21,7 @@ from collections import deque
from neo.tests import NeoTestBase
from neo.storage.app import Application
from neo.storage.handlers.storage import StorageOperationHandler
-from neo.protocol import INVALID_PARTITION
+from neo.protocol import INVALID_PARTITION, Packets
from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL
class StorageStorageHandlerTests(NeoTestBase):
@@ -113,7 +113,7 @@ class StorageStorageHandlerTests(NeoTest
self.assertEquals(len(self.app.event_queue), 0)
self.checkAnswerObject(conn)
- def test_25_askTIDsFrom1(self):
+ def test_25_askTIDsFrom(self):
# well case => answer
conn = self.getFakeConnection()
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
@@ -122,69 +122,85 @@ class StorageStorageHandlerTests(NeoTest
self.operation.askTIDsFrom(conn, tid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
- calls[0].checkArgs(tid, 2, 1, [1, ])
- self.checkAnswerTidsFrom(conn)
-
- def test_25_askTIDsFrom2(self):
- # invalid partition => answer usable partitions
- conn = self.getFakeConnection()
- cell = Mock({'getUUID':self.app.uuid})
- self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
- self.app.pt = Mock({
- 'getCellList': (cell, ),
- 'getPartitions': 1,
- 'getAssignedPartitionList': [0],
- })
- tid = self.getNextTID()
- self.operation.askTIDsFrom(conn, tid, 2, INVALID_PARTITION)
- self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
- calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
- self.assertEquals(len(calls), 1)
- calls[0].checkArgs(tid, 2, 1, [0, ])
+ calls[0].checkArgs(tid, 2, 1, 1)
self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self):
- oid = self.getOID(2)
- min_tid = self.getNextTID()
+ min_oid = self.getOID(2)
+ min_serial = self.getNextTID()
+ length = 4
+ partition = 8
+ num_partitions = 16
tid = self.getNextTID()
conn = self.getFakeConnection()
- self.app.dm = Mock({'getObjectHistoryFrom': [tid]})
- self.operation.askObjectHistoryFrom(conn, oid, min_tid, 2)
- self.checkAnswerObjectHistoryFrom(conn)
- calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
- self.assertEquals(len(calls), 1)
- calls[0].checkArgs(oid, min_tid, 2)
-
- def test_25_askOIDs1(self):
- # well case > answer OIDs
- conn = self.getFakeConnection()
- self.app.pt = Mock({'getPartitions': 1})
- self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
- oid = self.getOID(1)
- self.operation.askOIDs(conn, oid, 2, 1)
- calls = self.app.dm.mockGetNamedCalls('getOIDList')
- self.assertEquals(len(calls), 1)
- calls[0].checkArgs(oid, 2, 1, [1, ])
- self.checkAnswerOids(conn)
-
- def test_25_askOIDs2(self):
- # invalid partition => answer usable partitions
- conn = self.getFakeConnection()
- cell = Mock({'getUUID':self.app.uuid})
- self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
+ self.app.dm = Mock({'getObjectHistoryFrom': {min_oid: [tid]},})
self.app.pt = Mock({
- 'getCellList': (cell, ),
- 'getPartitions': 1,
- 'getAssignedPartitionList': [0],
+ 'getPartitions': num_partitions,
})
- oid = self.getOID(1)
- self.operation.askOIDs(conn, oid, 2, INVALID_PARTITION)
- self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
- calls = self.app.dm.mockGetNamedCalls('getOIDList')
+ self.operation.askObjectHistoryFrom(conn, min_oid, min_serial, length,
+ partition)
+ self.checkAnswerObjectHistoryFrom(conn)
+ calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEquals(len(calls), 1)
- calls[0].checkArgs(oid, 2, 1, [0])
- self.checkAnswerOids(conn)
+ calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
+ partition)
+ def test_askCheckTIDRange(self):
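+ # The handler forwards the range check to the database manager and
+ # answers with the count, checksum and highest TID it found.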
+ count = 1
+ tid_checksum = 2
+ min_tid = self.getNextTID()
+ num_partitions = 4
+ length = 5
+ partition = 6
+ max_tid = self.getNextTID()
+ self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)})
+ self.app.pt = Mock({'getPartitions': num_partitions})
+ conn = self.getFakeConnection()
+ self.operation.askCheckTIDRange(conn, min_tid, length, partition)
+ calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(min_tid, length, num_partitions, partition)
+ pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
+ self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
+ decode=True)
+ self.assertEqual(min_tid, pmin_tid)
+ self.assertEqual(length, plength)
+ self.assertEqual(count, pcount)
+ self.assertEqual(tid_checksum, ptid_checksum)
+ self.assertEqual(max_tid, pmax_tid)
+
+ def test_askCheckSerialRange(self):
+ count = 1
+ oid_checksum = 2
+ min_oid = self.getOID(1)
+ num_partitions = 4
+ length = 5
+ partition = 6
+ serial_checksum = 7
+ min_serial = self.getNextTID()
+ max_serial = self.getNextTID()
+ max_oid = self.getOID(2)
+ self.app.dm = Mock({'checkSerialRange': (count, oid_checksum, max_oid,
+ serial_checksum, max_serial)})
+ self.app.pt = Mock({'getPartitions': num_partitions})
+ conn = self.getFakeConnection()
+ self.operation.askCheckSerialRange(conn, min_oid, min_serial, length,
+ partition)
+ calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
+ self.assertEqual(len(calls), 1)
+ calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
+ partition)
+ pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
+ pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
+ Packets.AnswerCheckSerialRange, decode=True)
+ self.assertEqual(min_oid, pmin_oid)
+ self.assertEqual(min_serial, pmin_serial)
+ self.assertEqual(length, plength)
+ self.assertEqual(count, pcount)
+ self.assertEqual(oid_checksum, poid_checksum)
+ self.assertEqual(max_oid, pmax_oid)
+ self.assertEqual(serial_checksum, pserial_checksum)
+ self.assertEqual(max_serial, pmax_serial)
if __name__ == "__main__":
unittest.main()
Modified: trunk/neo/tests/storage/testStorageMySQLdb.py
==============================================================================
--- trunk/neo/tests/storage/testStorageMySQLdb.py [iso-8859-1] (original)
+++ trunk/neo/tests/storage/testStorageMySQLdb.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -19,7 +19,7 @@ import unittest
import MySQLdb
from mock import Mock
from neo.util import dump, p64, u64
-from neo.protocol import CellStates, INVALID_PTID
+from neo.protocol import CellStates, INVALID_PTID, ZERO_OID, ZERO_TID
from neo.tests import NeoTestBase
from neo.exception import DatabaseFailure
from neo.storage.database.mysqldb import MySQLDatabaseManager
@@ -441,6 +441,23 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
+ def test_deleteObject(self):
+ oid1, oid2 = self.getOIDs(2)
+ tid1, tid2 = self.getTIDs(2)
+ txn1, objs1 = self.getTransaction([oid1, oid2])
+ txn2, objs2 = self.getTransaction([oid1, oid2])
+ self.db.storeTransaction(tid1, objs1, txn1)
+ self.db.storeTransaction(tid2, objs2, txn2)
+ self.db.finishTransaction(tid1)
+ self.db.finishTransaction(tid2)
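+ # Deleting without a serial removes every stored revision of the object.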
+ self.db.deleteObject(oid1)
+ self.assertEqual(self.db.getObject(oid1, tid=tid1), None)
+ self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
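+ # Deleting a given serial removes only that revision; the later one
+ # remains readable.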
+ self.db.deleteObject(oid2, serial=tid1)
+ self.assertEqual(self.db.getObject(oid2, tid=tid1), False)
+ self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \
+ objs2[1][1:])
+
def test_getTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
@@ -459,30 +476,6 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False))
self.assertEqual(self.db.getTransaction(tid2, False), None)
- def test_getOIDList(self):
- # store four objects
- oid1, oid2, oid3, oid4 = self.getOIDs(4)
- tid = self.getNextTID()
- txn, objs = self.getTransaction([oid1, oid2, oid3, oid4])
- self.db.storeTransaction(tid, objs, txn)
- self.db.finishTransaction(tid)
- # get oids
- result = self.db.getOIDList(oid1, 4, 1, [0])
- self.checkSet(result, [oid1, oid2, oid3, oid4])
- result = self.db.getOIDList(oid1, 4, 2, [0])
- self.checkSet(result, [oid1, oid3])
- result = self.db.getOIDList(oid1, 4, 2, [0, 1])
- self.checkSet(result, [oid1, oid2, oid3, oid4])
- result = self.db.getOIDList(oid1, 4, 3, [0])
- self.checkSet(result, [oid1, oid4])
- # get a subset of oids
- result = self.db.getOIDList(oid1, 2, 1, [0])
- self.checkSet(result, [oid1, oid2])
- result = self.db.getOIDList(oid3, 2, 1, [0])
- self.checkSet(result, [oid3, oid4])
- result = self.db.getOIDList(oid2, 1, 3, [0])
- self.checkSet(result, [oid4])
-
def test_getObjectHistory(self):
oid = self.getOID(1)
tid1, tid2, tid3 = self.getTIDs(3)
@@ -506,6 +499,50 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getObjectHistory(oid, 2, 3)
self.assertEqual(result, None)
+ def test_getObjectHistoryFrom(self):
+ oid1 = self.getOID(0)
+ oid2 = self.getOID(1)
+ tid1, tid2, tid3, tid4 = self.getTIDs(4)
+ txn1, objs1 = self.getTransaction([oid1])
+ txn2, objs2 = self.getTransaction([oid2])
+ txn3, objs3 = self.getTransaction([oid1])
+ txn4, objs4 = self.getTransaction([oid2])
+ self.db.storeTransaction(tid1, objs1, txn1)
+ self.db.storeTransaction(tid2, objs2, txn2)
+ self.db.storeTransaction(tid3, objs3, txn3)
+ self.db.storeTransaction(tid4, objs4, txn4)
+ self.db.finishTransaction(tid1)
+ self.db.finishTransaction(tid2)
+ self.db.finishTransaction(tid3)
+ self.db.finishTransaction(tid4)
+ # Check full result
+ result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 1, 0)
+ self.assertEqual(result, {
+ oid1: [tid1, tid3],
+ oid2: [tid2, tid4],
+ })
+ # Lower bound is inclusive
+ result = self.db.getObjectHistoryFrom(oid1, tid1, 10, 1, 0)
+ self.assertEqual(result, {
+ oid1: [tid1, tid3],
+ oid2: [tid2, tid4],
+ })
+ # Length is total number of serials
+ result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 3, 1, 0)
+ self.assertEqual(result, {
+ oid1: [tid1, tid3],
+ oid2: [tid2],
+ })
+ # Partition constraints are honored
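+ # (with 2 partitions, oid1 falls in partition 0 and oid2 in partition 1)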
+ result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 0)
+ self.assertEqual(result, {
+ oid1: [tid1, tid3],
+ })
+ result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 1)
+ self.assertEqual(result, {
+ oid2: [tid2, tid4],
+ })
+
def _storeTransactions(self, count):
# use OID generator to know result of tid % N
tid_list = self.getOIDs(count)
@@ -538,59 +575,20 @@ class StorageMySQSLdbTests(NeoTestBase):
def test_getReplicationTIDList(self):
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
- result = self.db.getReplicationTIDList(tid1, 4, 1, [0])
+ result = self.db.getReplicationTIDList(tid1, 4, 1, 0)
self.checkSet(result, [tid1, tid2, tid3, tid4])
- result = self.db.getReplicationTIDList(tid1, 4, 2, [0])
+ result = self.db.getReplicationTIDList(tid1, 4, 2, 0)
self.checkSet(result, [tid1, tid3])
- result = self.db.getReplicationTIDList(tid1, 4, 2, [0, 1])
- self.checkSet(result, [tid1, tid2, tid3, tid4])
- result = self.db.getReplicationTIDList(tid1, 4, 3, [0])
+ result = self.db.getReplicationTIDList(tid1, 4, 3, 0)
self.checkSet(result, [tid1, tid4])
# get a subset of tids
- result = self.db.getReplicationTIDList(tid3, 4, 1, [0])
+ result = self.db.getReplicationTIDList(tid3, 4, 1, 0)
self.checkSet(result, [tid3, tid4])
- result = self.db.getReplicationTIDList(tid1, 2, 1, [0])
+ result = self.db.getReplicationTIDList(tid1, 2, 1, 0)
self.checkSet(result, [tid1, tid2])
- result = self.db.getReplicationTIDList(tid1, 1, 3, [1])
+ result = self.db.getReplicationTIDList(tid1, 1, 3, 1)
self.checkSet(result, [tid2])
- def test_getTIDListPresent(self):
- oid = self.getOID(1)
- tid1, tid2, tid3, tid4 = self.getTIDs(4)
- txn1, objs1 = self.getTransaction([oid])
- txn4, objs4 = self.getTransaction([oid])
- # four tids, two missing
- self.db.storeTransaction(tid1, objs1, txn1)
- self.db.finishTransaction(tid1)
- self.db.storeTransaction(tid4, objs4, txn4)
- self.db.finishTransaction(tid4)
- result = self.db.getTIDListPresent([tid1, tid2, tid3, tid4])
- self.checkSet(result, [tid1, tid4])
- result = self.db.getTIDListPresent([tid1, tid2])
- self.checkSet(result, [tid1])
- self.assertEqual(self.db.getTIDListPresent([tid2, tid3]), [])
-
- def test_getSerialListPresent(self):
- oid1, oid2 = self.getOIDs(2)
- tid1, tid2, tid3, tid4 = self.getTIDs(4)
- txn1, objs1 = self.getTransaction([oid1])
- txn2, objs2 = self.getTransaction([oid1])
- txn3, objs3 = self.getTransaction([oid2])
- txn4, objs4 = self.getTransaction([oid2])
- # four object, one revision each
- self.db.storeTransaction(tid1, objs1, txn1)
- self.db.finishTransaction(tid1)
- self.db.storeTransaction(tid4, objs4, txn4)
- self.db.finishTransaction(tid4)
- result = self.db.getSerialListPresent(oid1, [tid1, tid2])
- self.checkSet(result, [tid1])
- result = self.db.getSerialListPresent(oid2, [tid3, tid4])
- self.checkSet(result, [tid4])
- result = self.db.getSerialListPresent(oid1, [tid2])
- self.assertEqual(result, [])
- result = self.db.getSerialListPresent(oid2, [tid3])
- self.assertEqual(result, [])
-
def test__getObjectData(self):
db = self.db
db.setup(reset=True)
Modified: trunk/neo/tests/testProtocol.py
==============================================================================
--- trunk/neo/tests/testProtocol.py [iso-8859-1] (original)
+++ trunk/neo/tests/testProtocol.py [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -458,24 +458,6 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(p_hist_list, hist_list)
self.assertEqual(oid, poid)
- def test_55_askOIDs(self):
- oid = self.getOID(1)
- p = Packets.AskOIDs(oid, 1000, 5)
- min_oid, length, partition = p.decode()
- self.assertEqual(min_oid, oid)
- self.assertEqual(length, 1000)
- self.assertEqual(partition, 5)
-
- def test_56_answerOIDs(self):
- oid1 = self.getNextTID()
- oid2 = self.getNextTID()
- oid3 = self.getNextTID()
- oid4 = self.getNextTID()
- oid_list = [oid1, oid2, oid3, oid4]
- p = Packets.AnswerOIDs(oid_list)
- p_oid_list = p.decode()[0]
- self.assertEqual(p_oid_list, oid_list)
-
def test_57_notifyReplicationDone(self):
offset = 10
p = Packets.NotifyReplicationDone(offset)
@@ -626,14 +608,82 @@ class ProtocolTests(NeoTestBase):
oid = self.getOID(1)
min_serial = self.getNextTID()
length = 5
- p = Packets.AskObjectHistoryFrom(oid, min_serial, length)
- p_oid, p_min_serial, p_length = p.decode()
+ partition = 4
+ p = Packets.AskObjectHistoryFrom(oid, min_serial, length, partition)
+ p_oid, p_min_serial, p_length, p_partition = p.decode()
self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
+ self.assertEqual(p_partition, partition)
def test_AnswerObjectHistoryFrom(self):
- self._testXIDAndYIDList(Packets.AnswerObjectHistoryFrom)
+ object_dict = {}
+ for int_oid in xrange(4):
+ object_dict[self.getOID(int_oid)] = [self.getNextTID() \
+ for _ in xrange(5)]
+ p = Packets.AnswerObjectHistoryFrom(object_dict)
+ p_object_dict = p.decode()[0]
+ self.assertEqual(object_dict, p_object_dict)
+
+ def test_AskCheckTIDRange(self):
+ min_tid = self.getNextTID()
+ length = 2
+ partition = 4
+ p = Packets.AskCheckTIDRange(min_tid, length, partition)
+ p_min_tid, p_length, p_partition = p.decode()
+ self.assertEqual(p_min_tid, min_tid)
+ self.assertEqual(p_length, length)
+ self.assertEqual(p_partition, partition)
+
+ def test_AnswerCheckTIDRange(self):
+ min_tid = self.getNextTID()
+ length = 2
+ count = 1
+ tid_checksum = 42
+ max_tid = self.getNextTID()
+ p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
+ max_tid)
+ p_min_tid, p_length, p_count, p_tid_checksum, p_max_tid = p.decode()
+ self.assertEqual(p_min_tid, min_tid)
+ self.assertEqual(p_length, length)
+ self.assertEqual(p_count, count)
+ self.assertEqual(p_tid_checksum, tid_checksum)
+ self.assertEqual(p_max_tid, max_tid)
+
+ def test_AskCheckSerialRange(self):
+ min_oid = self.getOID(1)
+ min_serial = self.getNextTID()
+ length = 2
+ partition = 4
+ p = Packets.AskCheckSerialRange(min_oid, min_serial, length, partition)
+ p_min_oid, p_min_serial, p_length, p_partition = p.decode()
+ self.assertEqual(p_min_oid, min_oid)
+ self.assertEqual(p_min_serial, min_serial)
+ self.assertEqual(p_length, length)
+ self.assertEqual(p_partition, partition)
+
+ def test_AnswerCheckSerialRange(self):
+ min_oid = self.getOID(1)
+ min_serial = self.getNextTID()
+ length = 2
+ count = 1
+ oid_checksum = 24
+ max_oid = self.getOID(5)
+ tid_checksum = 42
+ max_serial = self.getNextTID()
+ p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
+ oid_checksum, max_oid, tid_checksum, max_serial)
+ p_min_oid, p_min_serial, p_length, p_count, p_oid_checksum, \
+ p_max_oid, p_tid_checksum, p_max_serial = p.decode()
+ self.assertEqual(p_min_oid, min_oid)
+ self.assertEqual(p_min_serial, min_serial)
+ self.assertEqual(p_length, length)
+ self.assertEqual(p_count, count)
+ self.assertEqual(p_oid_checksum, oid_checksum)
+ self.assertEqual(p_max_oid, max_oid)
+ self.assertEqual(p_tid_checksum, tid_checksum)
+ self.assertEqual(p_max_serial, max_serial)
+
def test_AskPack(self):
tid = self.getNextTID()
Modified: trunk/tools/runner
==============================================================================
--- trunk/tools/runner [iso-8859-1] (original)
+++ trunk/tools/runner [iso-8859-1] Sun Sep 5 11:15:04 2010
@@ -56,6 +56,8 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions',
+ 'neo.tests.storage.testReplicationHandler',
+ 'neo.tests.storage.testReplicator',
# client application
'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler',