From 24d8d99cfe0e8da968a08c75a8349d7a42f3a481 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 13 Jun 2026 00:07:44 +0000 Subject: [PATCH 1/4] refactor(show): extract render-command collection and title normalization Decompose the show() god-function (#697). First, behavior-preserving steps: - pass _validate_show_parameters args by keyword so a signature reorder can no longer silently misvalidate one parameter as another - extract _collect_render_commands() and _normalize_title() module helpers No behavior change. --- src/spatialdata_plot/pl/basic.py | 132 ++++++++++++++++++------------- 1 file changed, 78 insertions(+), 54 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index f864482c..f6e90f18 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -1373,31 +1373,31 @@ def show( ) from e _validate_show_parameters( - coordinate_systems, - legend_fontsize, - legend_fontweight, - legend_loc, - legend_fontoutline, - na_in_legend, - colorbar, - colorbar_params, - wspace, - hspace, - ncols, - frameon, - figsize, - dpi, - fig, - title, - pad_extent, - ax, - return_ax, - save, - show, - scalebar_dx, - scalebar_units, - scalebar_params, - legend_params, + coordinate_systems=coordinate_systems, + legend_fontsize=legend_fontsize, + legend_fontweight=legend_fontweight, + legend_loc=legend_loc, + legend_fontoutline=legend_fontoutline, + na_in_legend=na_in_legend, + colorbar=colorbar, + colorbar_params=colorbar_params, + wspace=wspace, + hspace=hspace, + ncols=ncols, + frameon=frameon, + figsize=figsize, + dpi=dpi, + fig=fig, + title=title, + pad_extent=pad_extent, + ax=ax, + return_ax=return_ax, + save=save, + show=show, + scalebar_dx=scalebar_dx, + scalebar_units=scalebar_units, + scalebar_params=scalebar_params, + legend_params=legend_params, ) if fig is not None: @@ -1412,36 +1412,9 @@ def show( sdata = self._copy() # Evaluate execution tree for plotting - valid_commands = [ - "render_images", - "render_shapes", - "render_labels", - "render_points", - "render_graph", - ] - - # prepare rendering params - render_cmds = [] - for cmd, params in plotting_tree.items(): - # strip prefix from cmd and verify it's valid - cmd = "_".join(cmd.split("_")[1:]) - - if cmd not in valid_commands: - raise ValueError(f"Command {cmd} is not valid.") - - if "render" in cmd: - # verify that rendering commands have been called before - render_cmds.append((cmd, params)) - - if not render_cmds: - raise TypeError("Please specify what to plot using the 'render_*' functions before calling 'imshow()'.") + render_cmds = _collect_render_commands(plotting_tree) - if title is not None: - if isinstance(title, str): - title = [title] - - if not all(isinstance(t, str) for t in title): - raise TypeError("All titles must be strings.") + title = _normalize_title(title) # Track whether the caller supplied their own axes so we can skip # plt.show() later (ax is reassigned inside the rendering loop). @@ -1888,3 +1861,54 @@ def _draw_colorbar( if show: plt.show() return (fig_params.ax if fig_params.axs is None else fig_params.axs) if return_ax else None # shuts up ruff + + +# Render commands queued on the plotting tree, as ``(command_name, render_params)`` pairs. +_RenderCmd = tuple[ + str, + ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams | GraphRenderParams, +] + + +def _collect_render_commands(plotting_tree: OrderedDict[str, Any]) -> list[_RenderCmd]: + """Extract and validate the queued ``render_*`` commands from the plotting tree. + + Each tree key is prefixed with its execution step (e.g. ``"1_render_images"``); the prefix + is stripped and the bare command validated. Raises if an unknown command is present or if no + render command was queued at all. + """ + valid_commands = [ + "render_images", + "render_shapes", + "render_labels", + "render_points", + "render_graph", + ] + + render_cmds: list[_RenderCmd] = [] + for cmd, params in plotting_tree.items(): + # strip the step-index prefix from cmd and verify it's valid + cmd = "_".join(cmd.split("_")[1:]) + + if cmd not in valid_commands: + raise ValueError(f"Command {cmd} is not valid.") + + if "render" in cmd: + # verify that rendering commands have been called before + render_cmds.append((cmd, params)) + + if not render_cmds: + raise TypeError("Please specify what to plot using the 'render_*' functions before calling 'imshow()'.") + + return render_cmds + + +def _normalize_title(title: list[str] | str | None) -> list[str] | None: + """Normalize the ``title`` argument to a list of strings (or ``None``).""" + if title is None: + return None + if isinstance(title, str): + title = [title] + if not all(isinstance(t, str) for t in title): + raise TypeError("All titles must be strings.") + return title From fcdf1c0584f36a89cd305e3b3a65b899eb696625 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 13 Jun 2026 00:11:40 +0000 Subject: [PATCH 2/4] refactor(show): extract CS resolution, panel planning, legend & colorbar stages Continue decomposing show() (#697), behavior-preserving: - _resolve_coordinate_systems(): CS auto-detection, validation and filtering - _plan_panels(): panel layout (one-per-CS vs one-per-color-key) + ax-count check - _build_legend_params(): LegendParams construction with legend_params overrides - promote the _draw_colorbar closure to a module function taking colorbar_params explicitly (was a ~90-line closure capturing it implicitly) - _layout_pending_colorbars(): the deferred second-pass colorbar layout No behavior change. --- src/spatialdata_plot/pl/basic.py | 482 +++++++++++++++++++------------ 1 file changed, 290 insertions(+), 192 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index f6e90f18..62af8f5c 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -47,6 +47,7 @@ CmapParams, ColorbarSpec, ColorLike, + FigParams, GraphRenderParams, ImageRenderParams, LabelsRenderParams, @@ -1428,84 +1429,29 @@ def show( ax_x_min, ax_x_max = ax.get_xlim() ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left - cs_was_auto = coordinate_systems is None - coordinate_systems = list(sdata.coordinate_systems) if cs_was_auto else coordinate_systems - if isinstance(coordinate_systems, str): - coordinate_systems = [coordinate_systems] - assert coordinate_systems is not None - - for cs in coordinate_systems: - if cs not in sdata.coordinate_systems: - raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}") - # Check if user specified only certain elements to be plotted cs_contents = _get_cs_contents(sdata) cs_index = cs_contents.set_index("cs") pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]] = [] - elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_index, cs) - - # filter out cs without relevant elements - cmds = [cmd for cmd, _ in render_cmds] - coordinate_systems = _get_valid_cs( + cs_was_auto = coordinate_systems is None + coordinate_systems = _resolve_coordinate_systems( sdata=sdata, coordinate_systems=coordinate_systems, - render_images="render_images" in cmds, - render_labels="render_labels" in cmds, - render_points="render_points" in cmds, - render_shapes="render_shapes" in cmds, - elements=elements_to_be_rendered, + cs_was_auto=cs_was_auto, + render_cmds=render_cmds, + cs_index=cs_index, + ax=ax, ) - # When CS was auto-detected and ax is provided, keep only CS that have - # element types for ALL render commands (workaround for upstream #176). - if ax is not None and cs_was_auto: - n_ax = 1 if isinstance(ax, Axes) else len(ax) - if len(coordinate_systems) > n_ax: - required_flags = [_RENDER_CMD_TO_CS_FLAG[cmd] for cmd in cmds if cmd in _RENDER_CMD_TO_CS_FLAG] - strict_cs = [ - cs_name - for cs_name in coordinate_systems - if cs_name in cs_index.index and all(cs_index.loc[cs_name][flag] for flag in required_flags) - ] - if strict_cs: - coordinate_systems = strict_cs - - # Determine the panel layout. Panels are normally one per coordinate system, but when a - # render_* call passed a list of color keys we instead lay out one panel per key within a - # single coordinate system (scanpy-style `color=[...]`). Render entries tagged with a - # `panel_key` belong to that key's panel; untagged entries are shared across all panels. - panel_keys: list[str] = [] - for _cmd, _params in render_cmds: - pkey = getattr(_params, "panel_key", None) - if pkey is not None and pkey not in panel_keys: - panel_keys.append(pkey) - if panel_keys: - if len(coordinate_systems) != 1: - raise ValueError( - "A list of color keys (multi-panel plotting) requires exactly one coordinate system, " - f"but {len(coordinate_systems)} were selected: {coordinate_systems}. " - "Pass `coordinate_systems=` to choose a single one." - ) - panels: list[tuple[str, str | None]] = [(coordinate_systems[0], key) for key in panel_keys] - else: - panels = [(cs, None) for cs in coordinate_systems] + panels = _plan_panels( + coordinate_systems=coordinate_systems, + render_cmds=render_cmds, + ax=ax, + cs_was_auto=cs_was_auto, + ) num_panels = len(panels) - if ax is not None: - n_ax = 1 if isinstance(ax, Axes) else len(ax) - if num_panels != n_ax: - msg = ( - f"Mismatch between number of matplotlib axes objects ({n_ax}) and number of panels ({num_panels})." - ) - if cs_was_auto: - msg += ( - " This can happen when elements have transformations to multiple " - "coordinate systems (e.g. after filter_by_coordinate_system). " - "Pass `coordinate_systems=` explicitly to select which ones to plot." - ) - raise ValueError(msg) - # set up canvas fig_params, scalebar_params_obj = _prepare_params_plot( num_panels=num_panels, @@ -1521,18 +1467,8 @@ def show( scalebar_units=scalebar_units, scalebar_kwargs=scalebar_params, ) - if legend_params: - legend_fontsize = legend_params.get("fontsize", legend_fontsize) - legend_fontweight = legend_params.get("fontweight", legend_fontweight) - # `loc` is matplotlib.Legend's native key; `location` aligns with colorbar/scalebar. - legend_loc = legend_params.get("location", legend_params.get("loc", legend_loc)) - legend_fontoutline = legend_params.get("fontoutline", legend_fontoutline) - na_in_legend = legend_params.get("na_in_legend", na_in_legend) - - if legend_loc == "on data": - raise ValueError("legend_loc='on data' is not supported in spatialdata-plot.") - - legend_params_obj = LegendParams( + legend_params_obj = _build_legend_params( + legend_params=legend_params, legend_fontsize=legend_fontsize, legend_fontweight=legend_fontweight, legend_loc=legend_loc, @@ -1543,96 +1479,6 @@ def show( outline_legend_title=outline_legend_title, ) - def _draw_colorbar( - spec: ColorbarSpec, - fig: Figure, - renderer: RendererBase, - base_offsets_axes: dict[str, float], - trackers_axes: dict[str, float], - ) -> None: - norm = spec.mappable.norm - if isinstance(norm, LogNorm): - vmin, vmax = norm.vmin, norm.vmax - if vmin is None or vmax is None or vmin <= 0 or vmin >= vmax: - warnings.warn( - "Data contains zeros or non-positive values; colorbar suppressed for `LogNorm`. " - "Pass `colorbar=False` to silence this warning, or clip the data to positive values.", - UserWarning, - stacklevel=2, - ) - return - - base_layout = { - "location": CBAR_DEFAULT_LOCATION, - "fraction": CBAR_DEFAULT_FRACTION, - "pad": CBAR_DEFAULT_PAD, - } - layer_layout, layer_kwargs, layer_label_override = _split_colorbar_params(spec.params) - global_layout, global_kwargs, global_label_override = _split_colorbar_params(colorbar_params) - layout = {**base_layout, **layer_layout, **global_layout} - cbar_kwargs = {**layer_kwargs, **global_kwargs} - - location = cast(str, layout.get("location", base_layout["location"])) - if location not in {"left", "right", "top", "bottom"}: - location = CBAR_DEFAULT_LOCATION - default_orientation = "vertical" if location in {"right", "left"} else "horizontal" - cbar_kwargs.setdefault("orientation", default_orientation) - - fraction = float(cast(float | int, layout.get("fraction", base_layout["fraction"]))) - pad = float(cast(float | int, layout.get("pad", base_layout["pad"]))) - - if location in {"left", "right"}: - pad_axes = pad + trackers_axes[location] - x0 = -pad_axes - fraction if location == "left" else 1 + pad_axes - bbox = (float(x0), 0.0, float(fraction), 1.0) - else: - pad_axes = pad + trackers_axes[location] - y0 = -pad_axes - fraction if location == "bottom" else 1 + pad_axes - bbox = (0.0, float(y0), 1.0, float(fraction)) - cax = inset_axes( - spec.ax, - width="100%", - height="100%", - loc="center", - bbox_to_anchor=bbox, - bbox_transform=spec.ax.transAxes, - borderpad=0.0, - ) - - cb = fig.colorbar(spec.mappable, cax=cax, **cbar_kwargs) - if location == "left": - cb.ax.yaxis.set_ticks_position("left") - cb.ax.yaxis.set_label_position("left") - cb.ax.tick_params(labelleft=True, labelright=False) - elif location == "top": - cb.ax.xaxis.set_ticks_position("top") - cb.ax.xaxis.set_label_position("top") - cb.ax.tick_params(labeltop=True, labelbottom=False) - elif location == "right": - cb.ax.yaxis.set_ticks_position("right") - cb.ax.yaxis.set_label_position("right") - cb.ax.tick_params(labelright=True, labelleft=False) - elif location == "bottom": - cb.ax.xaxis.set_ticks_position("bottom") - cb.ax.xaxis.set_label_position("bottom") - cb.ax.tick_params(labelbottom=True, labeltop=False) - - final_label = global_label_override or layer_label_override or spec.label - if final_label: - cb.set_label(final_label) - if spec.alpha is not None: - with contextlib.suppress(Exception): - cb.solids.set_alpha(spec.alpha) - bbox_axes = cb.ax.get_tightbbox(renderer).transformed(spec.ax.transAxes.inverted()) - if location == "left": - trackers_axes["left"] = pad_axes + bbox_axes.width - elif location == "right": - trackers_axes["right"] = pad_axes + bbox_axes.width - elif location == "bottom": - trackers_axes["bottom"] = pad_axes + bbox_axes.height - elif location == "top": - trackers_axes["top"] = pad_axes + bbox_axes.height - # go through tree for i, (cs, panel_key) in enumerate(panels): @@ -1822,29 +1668,7 @@ def _draw_colorbar( _draw_scalebar(ax, scalebar_params_obj, panel_idx=i) - if pending_colorbars and fig_params.fig is not None: - fig = fig_params.fig - fig.canvas.draw() - renderer = fig.canvas.get_renderer() - for axis, requests in pending_colorbars: - unique_specs: list[ColorbarSpec] = [] - seen_mappables: set[int] = set() - for spec in requests: - mappable_id = id(spec.mappable) - if mappable_id in seen_mappables: - continue - seen_mappables.add(mappable_id) - unique_specs.append(spec) - tight_bbox = axis.get_tightbbox(renderer).transformed(axis.transAxes.inverted()) - base_offsets_axes = { - "left": max(0.0, -tight_bbox.x0), - "right": max(0.0, tight_bbox.x1 - 1), - "bottom": max(0.0, -tight_bbox.y0), - "top": max(0.0, tight_bbox.y1 - 1), - } - trackers_axes = {k: base_offsets_axes[k] for k in base_offsets_axes} - for spec in unique_specs: - _draw_colorbar(spec, fig, renderer, base_offsets_axes, trackers_axes) + _layout_pending_colorbars(pending_colorbars, fig_params, colorbar_params) if fig_params.fig is not None and save is not None: save_fig(fig_params.fig, path=save) @@ -1912,3 +1736,277 @@ def _normalize_title(title: list[str] | str | None) -> list[str] | None: if not all(isinstance(t, str) for t in title): raise TypeError("All titles must be strings.") return title + + +def _resolve_coordinate_systems( + sdata: sd.SpatialData, + coordinate_systems: list[str] | str | None, + cs_was_auto: bool, + render_cmds: list[_RenderCmd], + cs_index: pd.DataFrame, + ax: list[Axes] | Axes | None, +) -> list[str]: + """Resolve, validate and filter the coordinate systems to render. + + Auto-detects all coordinate systems when ``coordinate_systems is None``, validates the + requested names, then drops systems that hold no element relevant to the queued render + commands. When axes are supplied alongside auto-detection, narrows to systems carrying an + element for *every* render command (workaround for upstream #176). + """ + coordinate_systems = list(sdata.coordinate_systems) if cs_was_auto else coordinate_systems + if isinstance(coordinate_systems, str): + coordinate_systems = [coordinate_systems] + assert coordinate_systems is not None + + for cs in coordinate_systems: + if cs not in sdata.coordinate_systems: + raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}") + + elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_index, cs) + + # filter out cs without relevant elements + cmds = [cmd for cmd, _ in render_cmds] + coordinate_systems = _get_valid_cs( + sdata=sdata, + coordinate_systems=coordinate_systems, + render_images="render_images" in cmds, + render_labels="render_labels" in cmds, + render_points="render_points" in cmds, + render_shapes="render_shapes" in cmds, + elements=elements_to_be_rendered, + ) + + # When CS was auto-detected and ax is provided, keep only CS that have + # element types for ALL render commands (workaround for upstream #176). + if ax is not None and cs_was_auto: + n_ax = 1 if isinstance(ax, Axes) else len(ax) + if len(coordinate_systems) > n_ax: + required_flags = [_RENDER_CMD_TO_CS_FLAG[cmd] for cmd in cmds if cmd in _RENDER_CMD_TO_CS_FLAG] + strict_cs = [ + cs_name + for cs_name in coordinate_systems + if cs_name in cs_index.index and all(cs_index.loc[cs_name][flag] for flag in required_flags) + ] + if strict_cs: + coordinate_systems = strict_cs + + return coordinate_systems + + +def _plan_panels( + coordinate_systems: list[str], + render_cmds: list[_RenderCmd], + ax: list[Axes] | Axes | None, + cs_was_auto: bool, +) -> list[tuple[str, str | None]]: + """Determine the panel layout as ``(coordinate_system, panel_key)`` tuples. + + Panels are normally one per coordinate system, but when a render_* call passed a list of + color keys we instead lay out one panel per key within a single coordinate system + (scanpy-style ``color=[...]``). Render entries tagged with a ``panel_key`` belong to that + key's panel; untagged entries are shared across all panels. Also validates the panel count + against any user-supplied axes. + """ + panel_keys: list[str] = [] + for _cmd, _params in render_cmds: + pkey = getattr(_params, "panel_key", None) + if pkey is not None and pkey not in panel_keys: + panel_keys.append(pkey) + if panel_keys: + if len(coordinate_systems) != 1: + raise ValueError( + "A list of color keys (multi-panel plotting) requires exactly one coordinate system, " + f"but {len(coordinate_systems)} were selected: {coordinate_systems}. " + "Pass `coordinate_systems=` to choose a single one." + ) + panels: list[tuple[str, str | None]] = [(coordinate_systems[0], key) for key in panel_keys] + else: + panels = [(cs, None) for cs in coordinate_systems] + + if ax is not None: + n_ax = 1 if isinstance(ax, Axes) else len(ax) + if len(panels) != n_ax: + msg = f"Mismatch between number of matplotlib axes objects ({n_ax}) and number of panels ({len(panels)})." + if cs_was_auto: + msg += ( + " This can happen when elements have transformations to multiple " + "coordinate systems (e.g. after filter_by_coordinate_system). " + "Pass `coordinate_systems=` explicitly to select which ones to plot." + ) + raise ValueError(msg) + + return panels + + +def _build_legend_params( + legend_params: dict[str, Any] | None, + legend_fontsize: int | float | _FontSize | None, + legend_fontweight: int | _FontWeight, + legend_loc: str | None, + legend_fontoutline: int | None, + na_in_legend: bool, + colorbar: bool, + legend_title: str | None, + outline_legend_title: str | None, +) -> LegendParams: + """Build the :class:`LegendParams` bundle, applying the ``legend_params`` dict overrides. + + Keys in the ``legend_params`` dict take precedence over the matching flat ``legend_*`` + keyword arguments. + """ + if legend_params: + legend_fontsize = legend_params.get("fontsize", legend_fontsize) + legend_fontweight = legend_params.get("fontweight", legend_fontweight) + # `loc` is matplotlib.Legend's native key; `location` aligns with colorbar/scalebar. + legend_loc = legend_params.get("location", legend_params.get("loc", legend_loc)) + legend_fontoutline = legend_params.get("fontoutline", legend_fontoutline) + na_in_legend = legend_params.get("na_in_legend", na_in_legend) + + if legend_loc == "on data": + raise ValueError("legend_loc='on data' is not supported in spatialdata-plot.") + + return LegendParams( + legend_fontsize=legend_fontsize, + legend_fontweight=legend_fontweight, + legend_loc=legend_loc, + legend_fontoutline=legend_fontoutline, + na_in_legend=na_in_legend, + colorbar=colorbar, + legend_title=legend_title, + outline_legend_title=outline_legend_title, + ) + + +def _draw_colorbar( + spec: ColorbarSpec, + fig: Figure, + renderer: RendererBase, + base_offsets_axes: dict[str, float], + trackers_axes: dict[str, float], + colorbar_params: dict[str, object] | None, +) -> None: + """Draw a single colorbar inset against ``spec.ax`` and update the side-offset trackers. + + ``trackers_axes`` accumulates, per side, how far out colorbars already extend so successive + colorbars on the same side stack instead of overlapping. + """ + norm = spec.mappable.norm + if isinstance(norm, LogNorm): + vmin, vmax = norm.vmin, norm.vmax + if vmin is None or vmax is None or vmin <= 0 or vmin >= vmax: + warnings.warn( + "Data contains zeros or non-positive values; colorbar suppressed for `LogNorm`. " + "Pass `colorbar=False` to silence this warning, or clip the data to positive values.", + UserWarning, + stacklevel=2, + ) + return + + base_layout = { + "location": CBAR_DEFAULT_LOCATION, + "fraction": CBAR_DEFAULT_FRACTION, + "pad": CBAR_DEFAULT_PAD, + } + layer_layout, layer_kwargs, layer_label_override = _split_colorbar_params(spec.params) + global_layout, global_kwargs, global_label_override = _split_colorbar_params(colorbar_params) + layout = {**base_layout, **layer_layout, **global_layout} + cbar_kwargs = {**layer_kwargs, **global_kwargs} + + location = cast(str, layout.get("location", base_layout["location"])) + if location not in {"left", "right", "top", "bottom"}: + location = CBAR_DEFAULT_LOCATION + default_orientation = "vertical" if location in {"right", "left"} else "horizontal" + cbar_kwargs.setdefault("orientation", default_orientation) + + fraction = float(cast(float | int, layout.get("fraction", base_layout["fraction"]))) + pad = float(cast(float | int, layout.get("pad", base_layout["pad"]))) + + if location in {"left", "right"}: + pad_axes = pad + trackers_axes[location] + x0 = -pad_axes - fraction if location == "left" else 1 + pad_axes + bbox = (float(x0), 0.0, float(fraction), 1.0) + else: + pad_axes = pad + trackers_axes[location] + y0 = -pad_axes - fraction if location == "bottom" else 1 + pad_axes + bbox = (0.0, float(y0), 1.0, float(fraction)) + cax = inset_axes( + spec.ax, + width="100%", + height="100%", + loc="center", + bbox_to_anchor=bbox, + bbox_transform=spec.ax.transAxes, + borderpad=0.0, + ) + + cb = fig.colorbar(spec.mappable, cax=cax, **cbar_kwargs) + if location == "left": + cb.ax.yaxis.set_ticks_position("left") + cb.ax.yaxis.set_label_position("left") + cb.ax.tick_params(labelleft=True, labelright=False) + elif location == "top": + cb.ax.xaxis.set_ticks_position("top") + cb.ax.xaxis.set_label_position("top") + cb.ax.tick_params(labeltop=True, labelbottom=False) + elif location == "right": + cb.ax.yaxis.set_ticks_position("right") + cb.ax.yaxis.set_label_position("right") + cb.ax.tick_params(labelright=True, labelleft=False) + elif location == "bottom": + cb.ax.xaxis.set_ticks_position("bottom") + cb.ax.xaxis.set_label_position("bottom") + cb.ax.tick_params(labelbottom=True, labeltop=False) + + final_label = global_label_override or layer_label_override or spec.label + if final_label: + cb.set_label(final_label) + if spec.alpha is not None: + with contextlib.suppress(Exception): + cb.solids.set_alpha(spec.alpha) + bbox_axes = cb.ax.get_tightbbox(renderer).transformed(spec.ax.transAxes.inverted()) + if location == "left": + trackers_axes["left"] = pad_axes + bbox_axes.width + elif location == "right": + trackers_axes["right"] = pad_axes + bbox_axes.width + elif location == "bottom": + trackers_axes["bottom"] = pad_axes + bbox_axes.height + elif location == "top": + trackers_axes["top"] = pad_axes + bbox_axes.height + + +def _layout_pending_colorbars( + pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]], + fig_params: FigParams, + colorbar_params: dict[str, object] | None, +) -> None: + """Second-pass colorbar layout, run once the canvas geometry is known. + + For each axis, deduplicates colorbar requests by mappable, seeds the per-side offset + trackers from the axis' tight bounding box, and draws each colorbar so they stack outward + without overlapping the axis content. + """ + if not (pending_colorbars and fig_params.fig is not None): + return + + fig = fig_params.fig + fig.canvas.draw() + renderer = fig.canvas.get_renderer() + for axis, requests in pending_colorbars: + unique_specs: list[ColorbarSpec] = [] + seen_mappables: set[int] = set() + for spec in requests: + mappable_id = id(spec.mappable) + if mappable_id in seen_mappables: + continue + seen_mappables.add(mappable_id) + unique_specs.append(spec) + tight_bbox = axis.get_tightbbox(renderer).transformed(axis.transAxes.inverted()) + base_offsets_axes = { + "left": max(0.0, -tight_bbox.x0), + "right": max(0.0, tight_bbox.x1 - 1), + "bottom": max(0.0, -tight_bbox.y0), + "top": max(0.0, tight_bbox.y1 - 1), + } + trackers_axes = {k: base_offsets_axes[k] for k in base_offsets_axes} + for spec in unique_specs: + _draw_colorbar(spec, fig, renderer, base_offsets_axes, trackers_axes, colorbar_params) From 6e70bc58bf94047d8945e3300ebd61c34de81d4b Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 13 Jun 2026 00:16:30 +0000 Subject: [PATCH 3/4] refactor(show): extract per-panel render dispatch into _render_panel/_finalize_panel Extract the inner render-command dispatch loop (#697), behavior-preserving: - _render_panel(): dispatches each queued render command into one panel's axes, returning the wanted elements and per-type wants_* flags - _finalize_panel(): per-panel title / equal-aspect / frame visibility show() is now a compact orchestrator calling named stages. Full test suite passes unchanged (719 passed, 1 skipped), image baselines included. No behavior change. --- src/spatialdata_plot/pl/basic.py | 332 ++++++++++++++++++------------- 1 file changed, 199 insertions(+), 133 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 62af8f5c..f0469c84 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -1493,139 +1493,25 @@ def show( axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params_obj.colorbar else None axis_channel_legend_entries: list[ChannelLegendEntry] = [] - wants_images = False - wants_labels = False - wants_points = False - wants_shapes = False - wanted_elements: list[str] = [] - - for cmd, params in render_cmds: - # Skip render entries that belong to a different color panel. Entries with no - # `panel_key` (None) are shared and drawn into every panel (e.g. a background image). - cmd_panel_key = getattr(params, "panel_key", None) - if panel_key is not None and cmd_panel_key is not None and cmd_panel_key != panel_key: - continue - # We create a copy here as the wanted elements can change from one cs to another. - params_copy = deepcopy(params) - if cmd == "render_images" and has_images: - wanted_elements, wanted_images_on_this_cs, wants_images = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "images" - ) - - if wanted_images_on_this_cs: - rasterize = (params_copy.scale is None) or ( - isinstance(params_copy.scale, str) - and params_copy.scale != "full" - and (dpi is not None or figsize is not None) - ) - _render_images( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - channel_legend_entries=axis_channel_legend_entries, - rasterize=rasterize, - ) - - elif cmd == "render_shapes" and has_shapes: - wanted_elements, wanted_shapes_on_this_cs, wants_shapes = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "shapes" - ) - - if wanted_shapes_on_this_cs: - _render_shapes( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - ) - - elif cmd == "render_points" and has_points: - wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "points" - ) - - if wanted_points_on_this_cs: - _render_points( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - ) - - elif cmd == "render_labels" and has_labels: - wanted_elements, wanted_labels_on_this_cs, wants_labels = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "labels" - ) - - if wanted_labels_on_this_cs: - table = params_copy.table_name - if table is not None and params_copy.col_for_color is not None: - colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color]) - if isinstance( - colors[params_copy.col_for_color].dtype, - pd.CategoricalDtype, - ): - _maybe_set_colors( - source=sdata[table], - target=sdata[table], - key=params_copy.col_for_color, - palette=params_copy.palette, - ) - - rasterize = (params_copy.scale is None) or ( - isinstance(params_copy.scale, str) - and params_copy.scale != "full" - and (dpi is not None or figsize is not None) - ) - _render_labels( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - rasterize=rasterize, - ) - - elif cmd == "render_graph": - graph_element = params_copy.element - element_in_cs = graph_element in sdata and cs in set( - get_transformation(sdata[graph_element], get_all=True).keys() - ) - if element_in_cs: - _render_graph( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - ) - - if title is None: - t = panel_key if panel_key is not None else cs - elif len(title) == 1: - t = title[0] - else: - try: - t = title[i] - except IndexError as e: - raise IndexError("The number of titles must match the number of panels.") from e - ax.set_title(t) - ax.set_aspect("equal") - if fig_params.frameon is False: - ax.axis("off") + wanted_elements, wants_images, wants_labels, wants_points, wants_shapes = _render_panel( + sdata=sdata, + render_cmds=render_cmds, + cs=cs, + panel_key=panel_key, + panel_idx=i, + ax=ax, + fig_params=fig_params, + legend_params_obj=legend_params_obj, + axis_colorbar_requests=axis_colorbar_requests, + axis_channel_legend_entries=axis_channel_legend_entries, + has_images=has_images, + has_labels=has_labels, + has_points=has_points, + has_shapes=has_shapes, + title=title, + dpi=dpi, + figsize=figsize, + ) if has_shapes and wants_shapes: empty_shape_elements = [ @@ -2010,3 +1896,183 @@ def _layout_pending_colorbars( trackers_axes = {k: base_offsets_axes[k] for k in base_offsets_axes} for spec in unique_specs: _draw_colorbar(spec, fig, renderer, base_offsets_axes, trackers_axes, colorbar_params) + + +def _finalize_panel( + ax: Axes, + panel_idx: int, + title: list[str] | None, + panel_key: str | None, + cs: str, + frameon: bool | None, +) -> None: + """Set a panel's title, equal aspect ratio and frame visibility. + + With no explicit ``title`` the panel is labelled with its color key (multi-panel color mode) + or its coordinate-system name; a single-element list applies to every panel, otherwise the + title at ``panel_idx`` is used. + """ + if title is None: + t = panel_key if panel_key is not None else cs + elif len(title) == 1: + t = title[0] + else: + try: + t = title[panel_idx] + except IndexError as e: + raise IndexError("The number of titles must match the number of panels.") from e + ax.set_title(t) + ax.set_aspect("equal") + if frameon is False: + ax.axis("off") + + +def _render_panel( + sdata: sd.SpatialData, + render_cmds: list[_RenderCmd], + cs: str, + panel_key: str | None, + panel_idx: int, + ax: Axes, + fig_params: FigParams, + legend_params_obj: LegendParams, + axis_colorbar_requests: list[ColorbarSpec] | None, + axis_channel_legend_entries: list[ChannelLegendEntry], + has_images: bool, + has_labels: bool, + has_points: bool, + has_shapes: bool, + title: list[str] | None, + dpi: int | None, + figsize: tuple[float, float] | None, +) -> tuple[list[str], bool, bool, bool, bool]: + """Render every applicable render command into a single panel's axes. + + Dispatches each queued render command to its ``_render_*`` function, skipping entries that + belong to a different color panel (``panel_key``). Colorbar requests and channel-legend + entries accumulate on the passed-in lists. Returns the wanted element names and the per-type + ``wants_*`` flags needed downstream for extent computation. + """ + wants_images = False + wants_labels = False + wants_points = False + wants_shapes = False + wanted_elements: list[str] = [] + + for cmd, params in render_cmds: + # Skip render entries that belong to a different color panel. Entries with no + # `panel_key` (None) are shared and drawn into every panel (e.g. a background image). + cmd_panel_key = getattr(params, "panel_key", None) + if panel_key is not None and cmd_panel_key is not None and cmd_panel_key != panel_key: + continue + # We create a copy here as the wanted elements can change from one cs to another. + params_copy = deepcopy(params) + if cmd == "render_images" and has_images: + wanted_elements, wanted_images_on_this_cs, wants_images = _get_wanted_render_elements( + sdata, wanted_elements, params_copy, cs, "images" + ) + + if wanted_images_on_this_cs: + rasterize = (params_copy.scale is None) or ( + isinstance(params_copy.scale, str) + and params_copy.scale != "full" + and (dpi is not None or figsize is not None) + ) + _render_images( + sdata=sdata, + render_params=params_copy, + coordinate_system=cs, + ax=ax, + fig_params=fig_params, + legend_params=legend_params_obj, + colorbar_requests=axis_colorbar_requests, + channel_legend_entries=axis_channel_legend_entries, + rasterize=rasterize, + ) + + elif cmd == "render_shapes" and has_shapes: + wanted_elements, wanted_shapes_on_this_cs, wants_shapes = _get_wanted_render_elements( + sdata, wanted_elements, params_copy, cs, "shapes" + ) + + if wanted_shapes_on_this_cs: + _render_shapes( + sdata=sdata, + render_params=params_copy, + coordinate_system=cs, + ax=ax, + fig_params=fig_params, + legend_params=legend_params_obj, + colorbar_requests=axis_colorbar_requests, + ) + + elif cmd == "render_points" and has_points: + wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements( + sdata, wanted_elements, params_copy, cs, "points" + ) + + if wanted_points_on_this_cs: + _render_points( + sdata=sdata, + render_params=params_copy, + coordinate_system=cs, + ax=ax, + fig_params=fig_params, + legend_params=legend_params_obj, + colorbar_requests=axis_colorbar_requests, + ) + + elif cmd == "render_labels" and has_labels: + wanted_elements, wanted_labels_on_this_cs, wants_labels = _get_wanted_render_elements( + sdata, wanted_elements, params_copy, cs, "labels" + ) + + if wanted_labels_on_this_cs: + table = params_copy.table_name + if table is not None and params_copy.col_for_color is not None: + colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color]) + if isinstance( + colors[params_copy.col_for_color].dtype, + pd.CategoricalDtype, + ): + _maybe_set_colors( + source=sdata[table], + target=sdata[table], + key=params_copy.col_for_color, + palette=params_copy.palette, + ) + + rasterize = (params_copy.scale is None) or ( + isinstance(params_copy.scale, str) + and params_copy.scale != "full" + and (dpi is not None or figsize is not None) + ) + _render_labels( + sdata=sdata, + render_params=params_copy, + coordinate_system=cs, + ax=ax, + fig_params=fig_params, + legend_params=legend_params_obj, + colorbar_requests=axis_colorbar_requests, + rasterize=rasterize, + ) + + elif cmd == "render_graph": + graph_element = params_copy.element + element_in_cs = graph_element in sdata and cs in set( + get_transformation(sdata[graph_element], get_all=True).keys() + ) + if element_in_cs: + _render_graph( + sdata=sdata, + render_params=params_copy, + coordinate_system=cs, + ax=ax, + legend_params=legend_params_obj, + colorbar_requests=axis_colorbar_requests, + ) + + _finalize_panel(ax, panel_idx, title, panel_key, cs, fig_params.frameon) + + return wanted_elements, wants_images, wants_labels, wants_points, wants_shapes From 83a3c7d0ecfb2f3dabb9f4d6ea4c1d95d41caa3e Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 13 Jun 2026 00:28:07 +0000 Subject: [PATCH 4/4] refactor(show): tighten extracted helpers, cut ~70 LOC Post-review cleanups (all behavior-preserving, full suite green 719 passed): - drop dead _draw_colorbar param base_offsets_axes (never read) - collapse the two parallel 4-branch location chains in _draw_colorbar into data-driven lookups (vertical flag + opposite map + getattr on the axis) - extract _should_rasterize() to dedup the images/labels rasterize heuristic - extract _maybe_set_label_colors() for the categorical-color prestep - replace the 5-branch render dispatch in _render_panel with a renderer table keyed by command; graph stays a small special case - pass cs_row instead of four has_* booleans; return a wants dict instead of a four-boolean tuple (removes the parallel-variable sprawl) --- src/spatialdata_plot/pl/basic.py | 235 +++++++++++-------------------- 1 file changed, 82 insertions(+), 153 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index f0469c84..b2f81142 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -1493,7 +1493,7 @@ def show( axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params_obj.colorbar else None axis_channel_legend_entries: list[ChannelLegendEntry] = [] - wanted_elements, wants_images, wants_labels, wants_points, wants_shapes = _render_panel( + wanted_elements, wants = _render_panel( sdata=sdata, render_cmds=render_cmds, cs=cs, @@ -1504,16 +1504,13 @@ def show( legend_params_obj=legend_params_obj, axis_colorbar_requests=axis_colorbar_requests, axis_channel_legend_entries=axis_channel_legend_entries, - has_images=has_images, - has_labels=has_labels, - has_points=has_points, - has_shapes=has_shapes, + cs_row=cs_row, title=title, dpi=dpi, figsize=figsize, ) - if has_shapes and wants_shapes: + if has_shapes and wants["shapes"]: empty_shape_elements = [ name for name in wanted_elements @@ -1528,10 +1525,10 @@ def show( extent = get_extent( sdata, coordinate_system=cs, - has_images=has_images and wants_images, - has_labels=has_labels and wants_labels, - has_points=has_points and wants_points, - has_shapes=has_shapes and wants_shapes, + has_images=has_images and wants["images"], + has_labels=has_labels and wants["labels"], + has_points=has_points and wants["points"], + has_shapes=has_shapes and wants["shapes"], elements=wanted_elements, ) cs_x_min, cs_x_max = extent["x"] @@ -1767,7 +1764,6 @@ def _draw_colorbar( spec: ColorbarSpec, fig: Figure, renderer: RendererBase, - base_offsets_axes: dict[str, float], trackers_axes: dict[str, float], colorbar_params: dict[str, object] | None, ) -> None: @@ -1807,14 +1803,11 @@ def _draw_colorbar( fraction = float(cast(float | int, layout.get("fraction", base_layout["fraction"]))) pad = float(cast(float | int, layout.get("pad", base_layout["pad"]))) - if location in {"left", "right"}: - pad_axes = pad + trackers_axes[location] - x0 = -pad_axes - fraction if location == "left" else 1 + pad_axes - bbox = (float(x0), 0.0, float(fraction), 1.0) - else: - pad_axes = pad + trackers_axes[location] - y0 = -pad_axes - fraction if location == "bottom" else 1 + pad_axes - bbox = (0.0, float(y0), 1.0, float(fraction)) + vertical = location in {"left", "right"} + pad_axes = pad + trackers_axes[location] + # "left"/"bottom" grow outward in the negative direction; "right"/"top" past 1. + start = -pad_axes - fraction if location in {"left", "bottom"} else 1 + pad_axes + bbox = (float(start), 0.0, float(fraction), 1.0) if vertical else (0.0, float(start), 1.0, float(fraction)) cax = inset_axes( spec.ax, width="100%", @@ -1826,22 +1819,11 @@ def _draw_colorbar( ) cb = fig.colorbar(spec.mappable, cax=cax, **cbar_kwargs) - if location == "left": - cb.ax.yaxis.set_ticks_position("left") - cb.ax.yaxis.set_label_position("left") - cb.ax.tick_params(labelleft=True, labelright=False) - elif location == "top": - cb.ax.xaxis.set_ticks_position("top") - cb.ax.xaxis.set_label_position("top") - cb.ax.tick_params(labeltop=True, labelbottom=False) - elif location == "right": - cb.ax.yaxis.set_ticks_position("right") - cb.ax.yaxis.set_label_position("right") - cb.ax.tick_params(labelright=True, labelleft=False) - elif location == "bottom": - cb.ax.xaxis.set_ticks_position("bottom") - cb.ax.xaxis.set_label_position("bottom") - cb.ax.tick_params(labelbottom=True, labeltop=False) + opposite = {"left": "right", "right": "left", "top": "bottom", "bottom": "top"}[location] + cbar_axis = cb.ax.yaxis if vertical else cb.ax.xaxis + cbar_axis.set_ticks_position(location) + cbar_axis.set_label_position(location) + cb.ax.tick_params(**{f"label{location}": True, f"label{opposite}": False}) final_label = global_label_override or layer_label_override or spec.label if final_label: @@ -1850,14 +1832,7 @@ def _draw_colorbar( with contextlib.suppress(Exception): cb.solids.set_alpha(spec.alpha) bbox_axes = cb.ax.get_tightbbox(renderer).transformed(spec.ax.transAxes.inverted()) - if location == "left": - trackers_axes["left"] = pad_axes + bbox_axes.width - elif location == "right": - trackers_axes["right"] = pad_axes + bbox_axes.width - elif location == "bottom": - trackers_axes["bottom"] = pad_axes + bbox_axes.height - elif location == "top": - trackers_axes["top"] = pad_axes + bbox_axes.height + trackers_axes[location] = pad_axes + (bbox_axes.width if vertical else bbox_axes.height) def _layout_pending_colorbars( @@ -1887,15 +1862,14 @@ def _layout_pending_colorbars( seen_mappables.add(mappable_id) unique_specs.append(spec) tight_bbox = axis.get_tightbbox(renderer).transformed(axis.transAxes.inverted()) - base_offsets_axes = { + trackers_axes = { "left": max(0.0, -tight_bbox.x0), "right": max(0.0, tight_bbox.x1 - 1), "bottom": max(0.0, -tight_bbox.y0), "top": max(0.0, tight_bbox.y1 - 1), } - trackers_axes = {k: base_offsets_axes[k] for k in base_offsets_axes} for spec in unique_specs: - _draw_colorbar(spec, fig, renderer, base_offsets_axes, trackers_axes, colorbar_params) + _draw_colorbar(spec, fig, renderer, trackers_axes, colorbar_params) def _finalize_panel( @@ -1927,6 +1901,31 @@ def _finalize_panel( ax.axis("off") +def _should_rasterize( + render_params: ImageRenderParams | LabelsRenderParams, + dpi: int | None, + figsize: tuple[float, float] | None, +) -> bool: + """Rasterize when no scale is set, or a non-``"full"`` scale is paired with a fixed canvas size.""" + scale = render_params.scale + return scale is None or (isinstance(scale, str) and scale != "full" and (dpi is not None or figsize is not None)) + + +def _maybe_set_label_colors(sdata: sd.SpatialData, render_params: LabelsRenderParams) -> None: + """Materialize a categorical palette on the table annotating a labels element, if applicable.""" + table = render_params.table_name + if table is None or render_params.col_for_color is None: + return + colors = sc.get.obs_df(sdata[table], [render_params.col_for_color]) + if isinstance(colors[render_params.col_for_color].dtype, pd.CategoricalDtype): + _maybe_set_colors( + source=sdata[table], + target=sdata[table], + key=render_params.col_for_color, + palette=render_params.palette, + ) + + def _render_panel( sdata: sd.SpatialData, render_cmds: list[_RenderCmd], @@ -1938,25 +1937,26 @@ def _render_panel( legend_params_obj: LegendParams, axis_colorbar_requests: list[ColorbarSpec] | None, axis_channel_legend_entries: list[ChannelLegendEntry], - has_images: bool, - has_labels: bool, - has_points: bool, - has_shapes: bool, + cs_row: pd.Series, title: list[str] | None, dpi: int | None, figsize: tuple[float, float] | None, -) -> tuple[list[str], bool, bool, bool, bool]: +) -> tuple[list[str], dict[str, bool]]: """Render every applicable render command into a single panel's axes. Dispatches each queued render command to its ``_render_*`` function, skipping entries that belong to a different color panel (``panel_key``). Colorbar requests and channel-legend - entries accumulate on the passed-in lists. Returns the wanted element names and the per-type - ``wants_*`` flags needed downstream for extent computation. + entries accumulate on the passed-in lists. Returns the wanted element names and a per-type + ``wants`` flag dict (keyed ``"images"``/``"labels"``/``"points"``/``"shapes"``) needed + downstream for extent computation. """ - wants_images = False - wants_labels = False - wants_points = False - wants_shapes = False + renderers = { + "render_images": _render_images, + "render_shapes": _render_shapes, + "render_points": _render_points, + "render_labels": _render_labels, + } + wants = dict.fromkeys(("images", "labels", "points", "shapes"), False) wanted_elements: list[str] = [] for cmd, params in render_cmds: @@ -1967,103 +1967,10 @@ def _render_panel( continue # We create a copy here as the wanted elements can change from one cs to another. params_copy = deepcopy(params) - if cmd == "render_images" and has_images: - wanted_elements, wanted_images_on_this_cs, wants_images = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "images" - ) - - if wanted_images_on_this_cs: - rasterize = (params_copy.scale is None) or ( - isinstance(params_copy.scale, str) - and params_copy.scale != "full" - and (dpi is not None or figsize is not None) - ) - _render_images( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - channel_legend_entries=axis_channel_legend_entries, - rasterize=rasterize, - ) - - elif cmd == "render_shapes" and has_shapes: - wanted_elements, wanted_shapes_on_this_cs, wants_shapes = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "shapes" - ) - if wanted_shapes_on_this_cs: - _render_shapes( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - ) - - elif cmd == "render_points" and has_points: - wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "points" - ) - - if wanted_points_on_this_cs: - _render_points( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - ) - - elif cmd == "render_labels" and has_labels: - wanted_elements, wanted_labels_on_this_cs, wants_labels = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "labels" - ) - - if wanted_labels_on_this_cs: - table = params_copy.table_name - if table is not None and params_copy.col_for_color is not None: - colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color]) - if isinstance( - colors[params_copy.col_for_color].dtype, - pd.CategoricalDtype, - ): - _maybe_set_colors( - source=sdata[table], - target=sdata[table], - key=params_copy.col_for_color, - palette=params_copy.palette, - ) - - rasterize = (params_copy.scale is None) or ( - isinstance(params_copy.scale, str) - and params_copy.scale != "full" - and (dpi is not None or figsize is not None) - ) - _render_labels( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - rasterize=rasterize, - ) - - elif cmd == "render_graph": + if cmd == "render_graph": graph_element = params_copy.element - element_in_cs = graph_element in sdata and cs in set( - get_transformation(sdata[graph_element], get_all=True).keys() - ) - if element_in_cs: + if graph_element in sdata and cs in get_transformation(sdata[graph_element], get_all=True): _render_graph( sdata=sdata, render_params=params_copy, @@ -2072,7 +1979,29 @@ def _render_panel( legend_params=legend_params_obj, colorbar_requests=axis_colorbar_requests, ) + elif cmd in renderers and cs_row[_RENDER_CMD_TO_CS_FLAG[cmd]]: + element_type = cmd.removeprefix("render_") + wanted_elements, wanted_on_cs, wants[element_type] = _get_wanted_render_elements( + sdata, wanted_elements, params_copy, cs, element_type + ) + if wanted_on_cs: + kwargs: dict[str, Any] = { + "sdata": sdata, + "render_params": params_copy, + "coordinate_system": cs, + "ax": ax, + "fig_params": fig_params, + "legend_params": legend_params_obj, + "colorbar_requests": axis_colorbar_requests, + } + if cmd == "render_images": + kwargs["channel_legend_entries"] = axis_channel_legend_entries + if cmd in {"render_images", "render_labels"}: + kwargs["rasterize"] = _should_rasterize(params_copy, dpi, figsize) + if cmd == "render_labels": + _maybe_set_label_colors(sdata, params_copy) + renderers[cmd](**kwargs) _finalize_panel(ax, panel_idx, title, panel_key, cs, fig_params.frameon) - return wanted_elements, wants_images, wants_labels, wants_points, wants_shapes + return wanted_elements, wants