Source code for bndl.compute.broadcast

from uuid import uuid4
import concurrent.futures
import json
import logging
import marshal
import pickle

from bndl.util import serialize, threads
from bndl.util.conf import Float
from bndl.util.funcs import identity
from cytoolz.functoolz import compose


# Bounds on the size of a single broadcast block. Configured in megabytes;
# broadcast() converts them to bytes (* 1024 * 1024) before serving data.
# Fix: the desc strings were swapped (min described as "maximum" and vice versa).
min_block_size = Float(4, desc='The minimum size of a block in megabytes.')  # MB
max_block_size = Float(16, desc='The maximum size of a block in megabytes.')  # MB


logger = logging.getLogger(__name__)

# Coordinates concurrent fetches of the same broadcast value within a process:
# BroadcastValue.value routes every access through this coordinator keyed by
# the block spec name, so the download/deserialize happens at most once.
download_coordinator = threads.Coordinator()

# Sentinel string marking a broadcast value that is not (or no longer)
# available. Not referenced in this chunk — presumably used by the block
# fetch/unpersist machinery elsewhere; verify against the worker code.
MISSING = 'bndl.compute.broadcast.MISSING'


def broadcast(ctx, value, serialization='auto', deserialization=None):
    '''
    Broadcast data to workers.

    Args:
        value (object): The value to broadcast.
        serialization (str): The format to serialize the broadcast value into.
            Must be one of auto, pickle, marshal, json, binary or text.
        deserialization (None or function(bytes)): Custom deserializer applied
            on the workers. If set, ``serialization`` must be ``None`` and
            ``value`` must be of type ``bytes``. Otherwise the natural
            counterpart of the chosen serialization format is used (e.g.
            ``bytes.decode`` followed by ``json.loads`` for 'json').

    Data can be 'shipped' along to workers in the closure of e.g. a mapper
    function, but in that case the data is sent once per partition (task to be
    precise). For 'larger' values (e.g. a MB or more) broadcasting is less
    wasteful. Note that the broadcast data is loaded on each worker (but only
    if the broadcast variable is used), so the machine running the workers
    should have enough memory to spare.

    Example usage::

        >>> tbl = ctx.broadcast(dict(zip(range(4), 'abcd')))
        >>> ctx.range(4).map(lambda i: tbl.value[i]).collect()
        ['a', 'b', 'c', 'd']
    '''
    # Resolve (data, deserialization) from the requested serialization format.
    if serialization is None:
        if not deserialization:
            raise ValueError('Must specify either serialization or deserialization')
        # Caller supplied pre-serialized bytes plus a custom deserializer.
        data = value
    elif deserialization is not None:
        raise ValueError("Can't specify both serialization and deserialization")
    elif serialization == 'auto':
        # serialize.dumps picks marshal when possible, pickle otherwise.
        marshalled, data = serialize.dumps(value)
        deserialization = marshal.loads if marshalled else pickle.loads
    elif serialization == 'pickle':
        data = pickle.dumps(value)
        deserialization = pickle.loads
    elif serialization == 'marshal':
        data = marshal.dumps(value)
        deserialization = marshal.loads
    elif serialization == 'json':
        data = json.dumps(value).encode()
        deserialization = compose(json.loads, bytes.decode)
    elif serialization == 'binary':
        data = value
        deserialization = identity
    elif serialization == 'text':
        data = value.encode()
        deserialization = bytes.decode
    else:
        raise ValueError('Unsupported serialization %s' % serialization)

    # Serve the serialized bytes as blocks under a unique key; block sizes are
    # bounded by the configured min/max (converted from MB to bytes here).
    key = str(uuid4())
    min_bytes = int(ctx.conf.get('bndl.compute.broadcast.min_block_size') * 1024 * 1024)
    max_bytes = int(ctx.conf.get('bndl.compute.broadcast.max_block_size') * 1024 * 1024)
    if min_bytes == max_bytes:
        block_spec = ctx.node.serve_data(key, data, max_bytes)
    else:
        # Aim for roughly two blocks per worker, clamped to the size bounds.
        block_spec = ctx.node.serve_data(key, data, (ctx.worker_count * 2, min_bytes, max_bytes))

    return BroadcastValue(ctx, ctx.node.name, block_spec, deserialization)
class BroadcastValue(object):
    '''
    Handle to a broadcast value.

    Accessing ``value`` fetches the blocks from peers (once per process, via
    ``download_coordinator``), joins and deserializes them. The seeder node
    can release the data cluster-wide with ``unpersist``.
    '''

    def __init__(self, ctx, seeder, block_spec, deserialize):
        self.ctx = ctx
        self.seeder = seeder
        self.block_spec = block_spec
        self.deserialize = deserialize

    @property
    def value(self):
        # Coordinate on the block name so concurrent accesses in this process
        # trigger at most one download/deserialize.
        return download_coordinator.coordinate(self._get, self.block_spec.name)

    def _get(self):
        local = self.ctx.node
        blocks = local.get_blocks(self.block_spec, local.peers.filter(node_type='worker'))
        result = self.deserialize(b''.join(blocks))
        # Non-seeder nodes drop their local copy of the raw blocks once the
        # value has been materialized (peers keep theirs).
        if local.name != self.block_spec.seeder:
            local.remove_blocks(self.block_spec.name, from_peers=False)
        return result

    def unpersist(self, block=False, timeout=None):
        '''
        Release the broadcast data locally and on all peers.

        Args:
            block (bool): Wait for the peers' responses when True.
            timeout: Optional per-peer timeout for the unpersist requests.
        '''
        local = self.ctx.node
        name = self.block_spec.name
        # Only the node that seeded the value may unpersist it.
        assert local.name == self.block_spec.seeder
        local.unpersist_broadcast_values(local, name)

        pending = [peer.unpersist_broadcast_values for peer in local.peers.filter()]
        if timeout:
            pending = [req.with_timeout(timeout) for req in pending]
        pending = [req(name) for req in pending]

        if block:
            for req in pending:
                try:
                    req.result()
                except concurrent.futures.TimeoutError:
                    # Best-effort: a slow peer doesn't fail the unpersist.
                    pass
                except Exception:
                    logger.warning('error while unpersisting %s', name, exc_info=True)

    def __del__(self):
        # The seeder cleans up cluster-wide when its handle is collected.
        if self.ctx.node and self.ctx.node.name == self.block_spec.seeder:
            self.unpersist()