264 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			264 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import json
 | |
| import traceback
 | |
| from contextvars import copy_context
 | |
| import asyncio
 | |
| from functools import partial
 | |
| 
 | |
| from _plotly_utils.utils import PlotlyJSONEncoder
 | |
| 
 | |
| from dash._callback_context import context_value
 | |
| from dash._utils import AttributeDict
 | |
| from dash.exceptions import PreventUpdate
 | |
| from dash.background_callback._proxy_set_props import ProxySetProps
 | |
| from dash.background_callback.managers import BaseBackgroundCallbackManager
 | |
| 
 | |
| 
 | |
class CeleryManager(BaseBackgroundCallbackManager):
    """Manage background execution of callbacks with a celery queue."""

    def __init__(self, celery_app, cache_by=None, expire=None):
        """
        Background callback manager that runs callback logic on a celery task queue,
        and stores results using a celery result backend.

        :param celery_app:
            A celery.Celery application instance that must be configured with a
            result backend. See the celery documentation for information on
            configuration options.
        :param cache_by:
            A list of zero-argument functions.  When provided, caching is enabled and
            the return values of these functions are combined with the callback
            function's input arguments, triggered inputs and source code to generate cache keys.
        :param expire:
            If provided, a cache entry will be removed when it has not been accessed
            for ``expire`` seconds.  If not provided, the lifetime of cache entries
            is determined by the default behavior of the celery result backend.
        :raises ImportError:
            If celery cannot be imported.
        :raises ValueError:
            If ``celery_app`` is not a ``celery.Celery`` instance, or if its
            result backend is disabled.
        """
        # celery is an optional dependency, so it is only imported when a
        # CeleryManager is actually constructed.
        try:
            import celery  # type: ignore[reportMissingImports]; pylint: disable=import-outside-toplevel,import-error
            from celery.backends.base import (  # type: ignore[reportMissingImports]; pylint: disable=import-outside-toplevel,import-error
                DisabledBackend,
            )
        except ImportError as missing_imports:
            raise ImportError(
                """\
CeleryManager requires extra dependencies which can be installed doing

    $ pip install "dash[celery]"\n"""
            ) from missing_imports

        if not isinstance(celery_app, celery.Celery):
            raise ValueError("First argument must be a celery.Celery object")

        if isinstance(celery_app.backend, DisabledBackend):
            raise ValueError("Celery instance must be configured with a result backend")

        self.handle = celery_app
        self.expire = expire
        super().__init__(cache_by)

    def terminate_job(self, job):
        """
        Ask celery to terminate the task identified by ``job``.

        :param job: A celery task id, or ``None`` (in which case nothing happens).
        """
        if job is None:
            return

        self.handle.control.terminate(job)

    def terminate_unhealthy_job(self, job):
        """
        Terminate ``job`` if its celery status is ``FAILURE`` or ``REVOKED``.

        :param job: A celery task id (may be falsy).
        :return:
            The return value of :meth:`terminate_job` when the task was
            unhealthy, ``False`` otherwise.
        """
        task = self.get_task(job)
        if task and task.status in ("FAILURE", "REVOKED"):
            return self.terminate_job(job)
        return False

    def job_running(self, job):
        """
        Tell whether the task identified by ``job`` is still in flight.

        :param job: A celery task id (may be falsy).
        :return:
            ``True`` when the task status is one of ``PENDING``, ``RECEIVED``,
            ``STARTED``, ``RETRY`` or ``PROGRESS``; ``False`` otherwise,
            including when there is no job.
        """
        task = self.get_task(job)
        if task is None:
            # Always answer with a real boolean rather than leaking ``None``.
            return False

        return task.status in (
            "PENDING",
            "RECEIVED",
            "STARTED",
            "RETRY",
            "PROGRESS",
        )

    def make_job_fn(self, fn, progress, key=None):
        """
        Wrap ``fn`` as a celery task registered on this manager's celery app.

        :param fn: The user callback function to run in the background.
        :param progress: Truthy when the callback reports progress.
        :param key: Identifier used to build the celery task name.
        :return: The celery task produced by ``_make_job_fn``.
        """
        return _make_job_fn(fn, self.handle, progress, key)

    def get_task(self, job):
        """
        Look up the celery ``AsyncResult`` for a task id.

        :param job: A celery task id (may be falsy).
        :return: ``self.handle.AsyncResult(job)``, or ``None`` when ``job`` is falsy.
        """
        if job:
            return self.handle.AsyncResult(job)

        return None

    def clear_cache_entry(self, key):
        """Delete ``key`` from the celery result backend."""
        self.handle.backend.delete(key)

    def call_job_fn(self, key, job_fn, args, context):
        """
        Queue ``job_fn`` for execution and return the id of the created task.

        :param key: Result key under which the job stores its output.
        :param job_fn: A celery task as returned by :meth:`make_job_fn`.
        :param args: Arguments forwarded to the user callback.
        :param context: Callback context forwarded to the job.
        :return: The celery task id.
        """
        task = job_fn.delay(key, self._make_progress_key(key), args, context)
        return task.task_id

    def get_progress(self, key):
        """
        Return the pending progress value for ``key`` and consume it.

        :param key: Result key of the background callback.
        :return:
            The JSON-decoded progress value, or ``None`` when the backend holds
            no (truthy) progress entry.  A returned entry is deleted from the
            backend so that it is only reported once.
        """
        progress_key = self._make_progress_key(key)
        progress_data = self.handle.backend.get(progress_key)
        if progress_data:
            self.handle.backend.delete(progress_key)
            return json.loads(progress_data)

        return None

    def result_ready(self, key):
        """Return ``True`` when the backend holds a value under ``key``."""
        return self.handle.backend.get(key) is not None

    def get_result(self, key, job):
        """
        Fetch the result stored under ``key`` and clean up after the job.

        :param key: Result key of the background callback.
        :param job: The celery task id of the job; it is terminated once the
            result has been read.
        :return:
            The JSON-decoded result, or ``self.UNDEFINED`` when the backend has
            no value for ``key`` (in which case nothing is cleaned up).
        """
        # Get result value
        result = self.handle.backend.get(key)
        if result is None:
            return self.UNDEFINED

        result = json.loads(result)

        if self.cache_by is None:
            # Clear result if not caching
            self.clear_cache_entry(key)
        elif self.expire:
            # Set/update expiration time
            # NOTE(review): relies on the configured backend exposing
            # ``expire(key, seconds)`` - confirm for non-default backends.
            self.handle.backend.expire(key, self.expire)

        # Progress information is meaningless once the result has been read.
        self.clear_cache_entry(self._make_progress_key(key))

        self.terminate_job(job)
        return result

    def get_updated_props(self, key):
        """
        Return the props set through ``set_props`` by the job and consume them.

        :param key: Result key of the background callback.
        :return:
            The JSON-decoded props stored under the set-props key derived from
            ``key``, or an empty dict when there are none.
        """
        set_props_key = self._make_set_props_key(key)
        updated_props = self.handle.backend.get(set_props_key)
        if updated_props is None:
            return {}

        # Delete the entry that was just read.  Deleting ``key`` here would
        # drop the callback *result* instead, and would leave the set-props
        # entry behind to be returned again on the next poll.
        self.clear_cache_entry(set_props_key)

        return json.loads(updated_props)
 | |
| 
 | |
| 
 | |
def _make_job_fn(fn, celery_app, progress, key):  # pylint: disable=too-many-statements
    """
    Build the celery task that runs a background callback.

    :param fn: The user callback; may be a plain function or a coroutine function.
    :param celery_app: The celery application on which the task is registered;
        its ``backend`` is used as the result store.
    :param progress: Truthy when the callback expects a ``set_progress``
        function as its first positional argument.
    :param key: Identifier used to name the task ``background_callback_<key>``.
    :return: The registered celery task.
    """
    cache = celery_app.backend

    @celery_app.task(name=f"background_callback_{key}")
    def job_fn(
        result_key, progress_key, user_callback_args, context=None
    ):  # pylint: disable=too-many-statements
        """
        Run the user callback and store its JSON-encoded outcome in the backend.

        The outcome written under ``result_key`` is the callback's return
        value, a ``_dash_no_update`` marker when it raised ``PreventUpdate``,
        or a ``background_callback_error`` payload when it raised anything else.
        """

        def _set_progress(progress_value):
            # Progress is always stored as a list.
            if not isinstance(progress_value, (list, tuple)):
                progress_value = [progress_value]

            cache.set(progress_key, json.dumps(progress_value, cls=PlotlyJSONEncoder))

        maybe_progress = [_set_progress] if progress else []

        def _set_props(_id, props):
            cache.set(
                f"{result_key}-set_props",
                json.dumps({_id: props}, cls=PlotlyJSONEncoder),
            )

        def _install_callback_context():
            # ``context`` defaults to None; fall back to an empty mapping so
            # that the ``**`` expansion cannot raise a TypeError.
            c = AttributeDict(**(context or {}))  # type: ignore[reportCallIssue]
            c.ignore_register_page = False
            c.updated_props = ProxySetProps(_set_props)
            context_value.set(c)

        def _call_user_callback():
            # Dispatch on the shape of the arguments: keyword, positional or
            # a single value.  For a coroutine function this returns the
            # (not yet awaited) coroutine.
            if isinstance(user_callback_args, dict):
                return fn(*maybe_progress, **user_callback_args)
            if isinstance(user_callback_args, (list, tuple)):
                return fn(*maybe_progress, *user_callback_args)
            return fn(*maybe_progress, user_callback_args)

        def _store_no_update():
            # Put NoUpdate dict directly to avoid circular imports.
            cache.set(
                result_key,
                json.dumps(
                    {"_dash_no_update": "_dash_no_update"}, cls=PlotlyJSONEncoder
                ),
            )

        def _store_error(err):
            # Must be called from inside the ``except`` block so that
            # ``traceback.format_exc()`` sees the active exception.
            cache.set(
                result_key,
                json.dumps(
                    {
                        "background_callback_error": {
                            "msg": str(err),
                            "tb": traceback.format_exc(),
                        }
                    },
                ),
            )

        def _store_output(output):
            cache.set(result_key, json.dumps(output, cls=PlotlyJSONEncoder))

        ctx = copy_context()

        def run():
            _install_callback_context()
            try:
                user_callback_output = _call_user_callback()
            except PreventUpdate:
                _store_no_update()
            except Exception as err:  # pylint: disable=broad-except
                _store_error(err)
            else:
                _store_output(user_callback_output)

        async def async_run():
            _install_callback_context()
            try:
                user_callback_output = await _call_user_callback()
                # The callback may itself hand back a coroutine; resolve it
                # here so that its exceptions are reported like any other.
                if asyncio.iscoroutine(user_callback_output):
                    user_callback_output = await user_callback_output
            except PreventUpdate:
                _store_no_update()
            except Exception as err:  # pylint: disable=broad-except
                _store_error(err)
            else:
                # Only reached on success, so the output is always bound here
                # (previously the error path hit an unbound local).
                _store_output(user_callback_output)

        if asyncio.iscoroutinefunction(fn):
            func = partial(ctx.run, async_run)
            asyncio.run(func())
        else:
            ctx.run(run)

    return job_fn
 | |
| 
 | |
| 
 | |
class CeleryLongCallbackManager(CeleryManager):
    """Deprecated name for ``CeleryManager``.

    Use `from dash import CeleryManager` instead.
    """
 |