Source code for orion.executor.dask_backend

import traceback

from orion.executor.base import (
    AsyncException,
    AsyncResult,
    BaseExecutor,
    ExecutorClosed,
    Future,
)

try:
    import dask.distributed
    from dask.distributed import Client, get_client, get_worker, rejoin, secede

    HAS_DASK = True
except ImportError:
    HAS_DASK = False


class _Future(Future):
    """Wraps a Dask Future"""

    def __init__(self, future):
        self.future = future
        self.exception = None

    def get(self, timeout=None):
        if self.exception:
            raise self.exception

        try:
            return self.future.result(timeout)
        except dask.distributed.TimeoutError as e:
            raise TimeoutError() from e

    def wait(self, timeout=None):
        try:
            self.future.result(timeout)
        except dask.distributed.TimeoutError:
            pass
        except Exception as e:
            self.exception = e

    def ready(self):
        return self.future.done()

    def successful(self):
        if not self.future.done():
            raise ValueError()

        return self.future.exception() is None


[docs]class Dask(BaseExecutor): """Wrapper around the dask client. .. warning:: The Dask executor can be pickled and used inside a subprocess, the pickled client will use the main client that was spawned in the main process, but you cannot spawn clients inside a subprocess. """ def __init__(self, n_workers=-1, client=None, **config): super().__init__(n_workers=n_workers) if not HAS_DASK: raise ImportError("Dask must be installed to use Dask executor.") self.config = config if client is None: client = Client(**self.config) self.client = client def __getstate__(self): return super().__getstate__() def __setstate__(self, state): super().__setstate__(state) self.client = get_client() @property def in_worker(self): try: get_worker() return True except ValueError: return False
[docs] def wait(self, futures): if self.in_worker: secede() results = self.client.gather(list(futures)) if self.in_worker: rejoin() return [r.get() for r in results]
[docs] def async_get(self, futures, timeout=0.01): results = [] tobe_deleted = [] for i, future in enumerate(futures): if timeout and i == 0: future.wait(timeout) if future.ready(): try: results.append(AsyncResult(future, future.get())) except Exception as err: results.append(AsyncException(future, err, traceback.format_exc())) tobe_deleted.append(future) for future in tobe_deleted: futures.remove(future) return results
[docs] def submit(self, function, *args, **kwargs): try: return _Future(self.client.submit(function, *args, **kwargs, pure=False)) except Exception as e: if str(e).startswith( "Tried sending message after closing. Status: closed" ): raise ExecutorClosed() from e raise
def __del__(self): # This is necessary because if the factory constructor fails # __del__ is executed right away but client might not be set if hasattr(self, "client"): self.client.close() def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.client.close() super().__exit__(exc_type, exc_value, traceback)