diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index f864482c..b2f81142 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -47,6 +47,7 @@ CmapParams, ColorbarSpec, ColorLike, + FigParams, GraphRenderParams, ImageRenderParams, LabelsRenderParams, @@ -1373,31 +1374,31 @@ def show( ) from e _validate_show_parameters( - coordinate_systems, - legend_fontsize, - legend_fontweight, - legend_loc, - legend_fontoutline, - na_in_legend, - colorbar, - colorbar_params, - wspace, - hspace, - ncols, - frameon, - figsize, - dpi, - fig, - title, - pad_extent, - ax, - return_ax, - save, - show, - scalebar_dx, - scalebar_units, - scalebar_params, - legend_params, + coordinate_systems=coordinate_systems, + legend_fontsize=legend_fontsize, + legend_fontweight=legend_fontweight, + legend_loc=legend_loc, + legend_fontoutline=legend_fontoutline, + na_in_legend=na_in_legend, + colorbar=colorbar, + colorbar_params=colorbar_params, + wspace=wspace, + hspace=hspace, + ncols=ncols, + frameon=frameon, + figsize=figsize, + dpi=dpi, + fig=fig, + title=title, + pad_extent=pad_extent, + ax=ax, + return_ax=return_ax, + save=save, + show=show, + scalebar_dx=scalebar_dx, + scalebar_units=scalebar_units, + scalebar_params=scalebar_params, + legend_params=legend_params, ) if fig is not None: @@ -1412,36 +1413,9 @@ def show( sdata = self._copy() # Evaluate execution tree for plotting - valid_commands = [ - "render_images", - "render_shapes", - "render_labels", - "render_points", - "render_graph", - ] - - # prepare rendering params - render_cmds = [] - for cmd, params in plotting_tree.items(): - # strip prefix from cmd and verify it's valid - cmd = "_".join(cmd.split("_")[1:]) - - if cmd not in valid_commands: - raise ValueError(f"Command {cmd} is not valid.") - - if "render" in cmd: - # verify that rendering commands have been called before - render_cmds.append((cmd, params)) - - if not render_cmds: - raise TypeError("Please specify what to plot using the 'render_*' functions before calling 'imshow()'.") - - if title is not None: - if isinstance(title, str): - title = [title] + render_cmds = _collect_render_commands(plotting_tree) - if not all(isinstance(t, str) for t in title): - raise TypeError("All titles must be strings.") + title = _normalize_title(title) # Track whether the caller supplied their own axes so we can skip # plt.show() later (ax is reassigned inside the rendering loop). @@ -1455,84 +1429,29 @@ def show( ax_x_min, ax_x_max = ax.get_xlim() ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left - cs_was_auto = coordinate_systems is None - coordinate_systems = list(sdata.coordinate_systems) if cs_was_auto else coordinate_systems - if isinstance(coordinate_systems, str): - coordinate_systems = [coordinate_systems] - assert coordinate_systems is not None - - for cs in coordinate_systems: - if cs not in sdata.coordinate_systems: - raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}") - # Check if user specified only certain elements to be plotted cs_contents = _get_cs_contents(sdata) cs_index = cs_contents.set_index("cs") pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]] = [] - elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_index, cs) - - # filter out cs without relevant elements - cmds = [cmd for cmd, _ in render_cmds] - coordinate_systems = _get_valid_cs( + cs_was_auto = coordinate_systems is None + coordinate_systems = _resolve_coordinate_systems( sdata=sdata, coordinate_systems=coordinate_systems, - render_images="render_images" in cmds, - render_labels="render_labels" in cmds, - render_points="render_points" in cmds, - render_shapes="render_shapes" in cmds, - elements=elements_to_be_rendered, + cs_was_auto=cs_was_auto, + render_cmds=render_cmds, + cs_index=cs_index, + ax=ax, ) - # When CS was auto-detected and ax is provided, keep only CS that have - # element types for ALL render commands (workaround for upstream #176). - if ax is not None and cs_was_auto: - n_ax = 1 if isinstance(ax, Axes) else len(ax) - if len(coordinate_systems) > n_ax: - required_flags = [_RENDER_CMD_TO_CS_FLAG[cmd] for cmd in cmds if cmd in _RENDER_CMD_TO_CS_FLAG] - strict_cs = [ - cs_name - for cs_name in coordinate_systems - if cs_name in cs_index.index and all(cs_index.loc[cs_name][flag] for flag in required_flags) - ] - if strict_cs: - coordinate_systems = strict_cs - - # Determine the panel layout. Panels are normally one per coordinate system, but when a - # render_* call passed a list of color keys we instead lay out one panel per key within a - # single coordinate system (scanpy-style `color=[...]`). Render entries tagged with a - # `panel_key` belong to that key's panel; untagged entries are shared across all panels. - panel_keys: list[str] = [] - for _cmd, _params in render_cmds: - pkey = getattr(_params, "panel_key", None) - if pkey is not None and pkey not in panel_keys: - panel_keys.append(pkey) - if panel_keys: - if len(coordinate_systems) != 1: - raise ValueError( - "A list of color keys (multi-panel plotting) requires exactly one coordinate system, " - f"but {len(coordinate_systems)} were selected: {coordinate_systems}. " - "Pass `coordinate_systems=` to choose a single one." - ) - panels: list[tuple[str, str | None]] = [(coordinate_systems[0], key) for key in panel_keys] - else: - panels = [(cs, None) for cs in coordinate_systems] + panels = _plan_panels( + coordinate_systems=coordinate_systems, + render_cmds=render_cmds, + ax=ax, + cs_was_auto=cs_was_auto, + ) num_panels = len(panels) - if ax is not None: - n_ax = 1 if isinstance(ax, Axes) else len(ax) - if num_panels != n_ax: - msg = ( - f"Mismatch between number of matplotlib axes objects ({n_ax}) and number of panels ({num_panels})." - ) - if cs_was_auto: - msg += ( - " This can happen when elements have transformations to multiple " - "coordinate systems (e.g. after filter_by_coordinate_system). " - "Pass `coordinate_systems=` explicitly to select which ones to plot." - ) - raise ValueError(msg) - # set up canvas fig_params, scalebar_params_obj = _prepare_params_plot( num_panels=num_panels, @@ -1548,18 +1467,8 @@ def show( scalebar_units=scalebar_units, scalebar_kwargs=scalebar_params, ) - if legend_params: - legend_fontsize = legend_params.get("fontsize", legend_fontsize) - legend_fontweight = legend_params.get("fontweight", legend_fontweight) - # `loc` is matplotlib.Legend's native key; `location` aligns with colorbar/scalebar. - legend_loc = legend_params.get("location", legend_params.get("loc", legend_loc)) - legend_fontoutline = legend_params.get("fontoutline", legend_fontoutline) - na_in_legend = legend_params.get("na_in_legend", na_in_legend) - - if legend_loc == "on data": - raise ValueError("legend_loc='on data' is not supported in spatialdata-plot.") - - legend_params_obj = LegendParams( + legend_params_obj = _build_legend_params( + legend_params=legend_params, legend_fontsize=legend_fontsize, legend_fontweight=legend_fontweight, legend_loc=legend_loc, @@ -1570,96 +1479,6 @@ def show( outline_legend_title=outline_legend_title, ) - def _draw_colorbar( - spec: ColorbarSpec, - fig: Figure, - renderer: RendererBase, - base_offsets_axes: dict[str, float], - trackers_axes: dict[str, float], - ) -> None: - norm = spec.mappable.norm - if isinstance(norm, LogNorm): - vmin, vmax = norm.vmin, norm.vmax - if vmin is None or vmax is None or vmin <= 0 or vmin >= vmax: - warnings.warn( - "Data contains zeros or non-positive values; colorbar suppressed for `LogNorm`. " - "Pass `colorbar=False` to silence this warning, or clip the data to positive values.", - UserWarning, - stacklevel=2, - ) - return - - base_layout = { - "location": CBAR_DEFAULT_LOCATION, - "fraction": CBAR_DEFAULT_FRACTION, - "pad": CBAR_DEFAULT_PAD, - } - layer_layout, layer_kwargs, layer_label_override = _split_colorbar_params(spec.params) - global_layout, global_kwargs, global_label_override = _split_colorbar_params(colorbar_params) - layout = {**base_layout, **layer_layout, **global_layout} - cbar_kwargs = {**layer_kwargs, **global_kwargs} - - location = cast(str, layout.get("location", base_layout["location"])) - if location not in {"left", "right", "top", "bottom"}: - location = CBAR_DEFAULT_LOCATION - default_orientation = "vertical" if location in {"right", "left"} else "horizontal" - cbar_kwargs.setdefault("orientation", default_orientation) - - fraction = float(cast(float | int, layout.get("fraction", base_layout["fraction"]))) - pad = float(cast(float | int, layout.get("pad", base_layout["pad"]))) - - if location in {"left", "right"}: - pad_axes = pad + trackers_axes[location] - x0 = -pad_axes - fraction if location == "left" else 1 + pad_axes - bbox = (float(x0), 0.0, float(fraction), 1.0) - else: - pad_axes = pad + trackers_axes[location] - y0 = -pad_axes - fraction if location == "bottom" else 1 + pad_axes - bbox = (0.0, float(y0), 1.0, float(fraction)) - cax = inset_axes( - spec.ax, - width="100%", - height="100%", - loc="center", - bbox_to_anchor=bbox, - bbox_transform=spec.ax.transAxes, - borderpad=0.0, - ) - - cb = fig.colorbar(spec.mappable, cax=cax, **cbar_kwargs) - if location == "left": - cb.ax.yaxis.set_ticks_position("left") - cb.ax.yaxis.set_label_position("left") - cb.ax.tick_params(labelleft=True, labelright=False) - elif location == "top": - cb.ax.xaxis.set_ticks_position("top") - cb.ax.xaxis.set_label_position("top") - cb.ax.tick_params(labeltop=True, labelbottom=False) - elif location == "right": - cb.ax.yaxis.set_ticks_position("right") - cb.ax.yaxis.set_label_position("right") - cb.ax.tick_params(labelright=True, labelleft=False) - elif location == "bottom": - cb.ax.xaxis.set_ticks_position("bottom") - cb.ax.xaxis.set_label_position("bottom") - cb.ax.tick_params(labelbottom=True, labeltop=False) - - final_label = global_label_override or layer_label_override or spec.label - if final_label: - cb.set_label(final_label) - if spec.alpha is not None: - with contextlib.suppress(Exception): - cb.solids.set_alpha(spec.alpha) - bbox_axes = cb.ax.get_tightbbox(renderer).transformed(spec.ax.transAxes.inverted()) - if location == "left": - trackers_axes["left"] = pad_axes + bbox_axes.width - elif location == "right": - trackers_axes["right"] = pad_axes + bbox_axes.width - elif location == "bottom": - trackers_axes["bottom"] = pad_axes + bbox_axes.height - elif location == "top": - trackers_axes["top"] = pad_axes + bbox_axes.height - # go through tree for i, (cs, panel_key) in enumerate(panels): @@ -1674,141 +1493,24 @@ def _draw_colorbar( axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params_obj.colorbar else None axis_channel_legend_entries: list[ChannelLegendEntry] = [] - wants_images = False - wants_labels = False - wants_points = False - wants_shapes = False - wanted_elements: list[str] = [] - - for cmd, params in render_cmds: - # Skip render entries that belong to a different color panel. Entries with no - # `panel_key` (None) are shared and drawn into every panel (e.g. a background image). - cmd_panel_key = getattr(params, "panel_key", None) - if panel_key is not None and cmd_panel_key is not None and cmd_panel_key != panel_key: - continue - # We create a copy here as the wanted elements can change from one cs to another. - params_copy = deepcopy(params) - if cmd == "render_images" and has_images: - wanted_elements, wanted_images_on_this_cs, wants_images = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "images" - ) - - if wanted_images_on_this_cs: - rasterize = (params_copy.scale is None) or ( - isinstance(params_copy.scale, str) - and params_copy.scale != "full" - and (dpi is not None or figsize is not None) - ) - _render_images( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - channel_legend_entries=axis_channel_legend_entries, - rasterize=rasterize, - ) - - elif cmd == "render_shapes" and has_shapes: - wanted_elements, wanted_shapes_on_this_cs, wants_shapes = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "shapes" - ) - - if wanted_shapes_on_this_cs: - _render_shapes( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - ) - - elif cmd == "render_points" and has_points: - wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "points" - ) - - if wanted_points_on_this_cs: - _render_points( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - ) - - elif cmd == "render_labels" and has_labels: - wanted_elements, wanted_labels_on_this_cs, wants_labels = _get_wanted_render_elements( - sdata, wanted_elements, params_copy, cs, "labels" - ) - - if wanted_labels_on_this_cs: - table = params_copy.table_name - if table is not None and params_copy.col_for_color is not None: - colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color]) - if isinstance( - colors[params_copy.col_for_color].dtype, - pd.CategoricalDtype, - ): - _maybe_set_colors( - source=sdata[table], - target=sdata[table], - key=params_copy.col_for_color, - palette=params_copy.palette, - ) - - rasterize = (params_copy.scale is None) or ( - isinstance(params_copy.scale, str) - and params_copy.scale != "full" - and (dpi is not None or figsize is not None) - ) - _render_labels( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - fig_params=fig_params, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - rasterize=rasterize, - ) - - elif cmd == "render_graph": - graph_element = params_copy.element - element_in_cs = graph_element in sdata and cs in set( - get_transformation(sdata[graph_element], get_all=True).keys() - ) - if element_in_cs: - _render_graph( - sdata=sdata, - render_params=params_copy, - coordinate_system=cs, - ax=ax, - legend_params=legend_params_obj, - colorbar_requests=axis_colorbar_requests, - ) + wanted_elements, wants = _render_panel( + sdata=sdata, + render_cmds=render_cmds, + cs=cs, + panel_key=panel_key, + panel_idx=i, + ax=ax, + fig_params=fig_params, + legend_params_obj=legend_params_obj, + axis_colorbar_requests=axis_colorbar_requests, + axis_channel_legend_entries=axis_channel_legend_entries, + cs_row=cs_row, + title=title, + dpi=dpi, + figsize=figsize, + ) - if title is None: - t = panel_key if panel_key is not None else cs - elif len(title) == 1: - t = title[0] - else: - try: - t = title[i] - except IndexError as e: - raise IndexError("The number of titles must match the number of panels.") from e - ax.set_title(t) - ax.set_aspect("equal") - if fig_params.frameon is False: - ax.axis("off") - - if has_shapes and wants_shapes: + if has_shapes and wants["shapes"]: empty_shape_elements = [ name for name in wanted_elements @@ -1823,10 +1525,10 @@ def _draw_colorbar( extent = get_extent( sdata, coordinate_system=cs, - has_images=has_images and wants_images, - has_labels=has_labels and wants_labels, - has_points=has_points and wants_points, - has_shapes=has_shapes and wants_shapes, + has_images=has_images and wants["images"], + has_labels=has_labels and wants["labels"], + has_points=has_points and wants["points"], + has_shapes=has_shapes and wants["shapes"], elements=wanted_elements, ) cs_x_min, cs_x_max = extent["x"] @@ -1849,29 +1551,7 @@ def _draw_colorbar( _draw_scalebar(ax, scalebar_params_obj, panel_idx=i) - if pending_colorbars and fig_params.fig is not None: - fig = fig_params.fig - fig.canvas.draw() - renderer = fig.canvas.get_renderer() - for axis, requests in pending_colorbars: - unique_specs: list[ColorbarSpec] = [] - seen_mappables: set[int] = set() - for spec in requests: - mappable_id = id(spec.mappable) - if mappable_id in seen_mappables: - continue - seen_mappables.add(mappable_id) - unique_specs.append(spec) - tight_bbox = axis.get_tightbbox(renderer).transformed(axis.transAxes.inverted()) - base_offsets_axes = { - "left": max(0.0, -tight_bbox.x0), - "right": max(0.0, tight_bbox.x1 - 1), - "bottom": max(0.0, -tight_bbox.y0), - "top": max(0.0, tight_bbox.y1 - 1), - } - trackers_axes = {k: base_offsets_axes[k] for k in base_offsets_axes} - for spec in unique_specs: - _draw_colorbar(spec, fig, renderer, base_offsets_axes, trackers_axes) + _layout_pending_colorbars(pending_colorbars, fig_params, colorbar_params) if fig_params.fig is not None and save is not None: save_fig(fig_params.fig, path=save) @@ -1888,3 +1568,440 @@ def _draw_colorbar( if show: plt.show() return (fig_params.ax if fig_params.axs is None else fig_params.axs) if return_ax else None # shuts up ruff + + +# Render commands queued on the plotting tree, as ``(command_name, render_params)`` pairs. +_RenderCmd = tuple[ + str, + ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams | GraphRenderParams, +] + + +def _collect_render_commands(plotting_tree: OrderedDict[str, Any]) -> list[_RenderCmd]: + """Extract and validate the queued ``render_*`` commands from the plotting tree. + + Each tree key is prefixed with its execution step (e.g. ``"1_render_images"``); the prefix + is stripped and the bare command validated. Raises if an unknown command is present or if no + render command was queued at all. + """ + valid_commands = [ + "render_images", + "render_shapes", + "render_labels", + "render_points", + "render_graph", + ] + + render_cmds: list[_RenderCmd] = [] + for cmd, params in plotting_tree.items(): + # strip the step-index prefix from cmd and verify it's valid + cmd = "_".join(cmd.split("_")[1:]) + + if cmd not in valid_commands: + raise ValueError(f"Command {cmd} is not valid.") + + if "render" in cmd: + # verify that rendering commands have been called before + render_cmds.append((cmd, params)) + + if not render_cmds: + raise TypeError("Please specify what to plot using the 'render_*' functions before calling 'imshow()'.") + + return render_cmds + + +def _normalize_title(title: list[str] | str | None) -> list[str] | None: + """Normalize the ``title`` argument to a list of strings (or ``None``).""" + if title is None: + return None + if isinstance(title, str): + title = [title] + if not all(isinstance(t, str) for t in title): + raise TypeError("All titles must be strings.") + return title + + +def _resolve_coordinate_systems( + sdata: sd.SpatialData, + coordinate_systems: list[str] | str | None, + cs_was_auto: bool, + render_cmds: list[_RenderCmd], + cs_index: pd.DataFrame, + ax: list[Axes] | Axes | None, +) -> list[str]: + """Resolve, validate and filter the coordinate systems to render. + + Auto-detects all coordinate systems when ``coordinate_systems is None``, validates the + requested names, then drops systems that hold no element relevant to the queued render + commands. When axes are supplied alongside auto-detection, narrows to systems carrying an + element for *every* render command (workaround for upstream #176). + """ + coordinate_systems = list(sdata.coordinate_systems) if cs_was_auto else coordinate_systems + if isinstance(coordinate_systems, str): + coordinate_systems = [coordinate_systems] + assert coordinate_systems is not None + + for cs in coordinate_systems: + if cs not in sdata.coordinate_systems: + raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}") + + elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_index, cs) + + # filter out cs without relevant elements + cmds = [cmd for cmd, _ in render_cmds] + coordinate_systems = _get_valid_cs( + sdata=sdata, + coordinate_systems=coordinate_systems, + render_images="render_images" in cmds, + render_labels="render_labels" in cmds, + render_points="render_points" in cmds, + render_shapes="render_shapes" in cmds, + elements=elements_to_be_rendered, + ) + + # When CS was auto-detected and ax is provided, keep only CS that have + # element types for ALL render commands (workaround for upstream #176). + if ax is not None and cs_was_auto: + n_ax = 1 if isinstance(ax, Axes) else len(ax) + if len(coordinate_systems) > n_ax: + required_flags = [_RENDER_CMD_TO_CS_FLAG[cmd] for cmd in cmds if cmd in _RENDER_CMD_TO_CS_FLAG] + strict_cs = [ + cs_name + for cs_name in coordinate_systems + if cs_name in cs_index.index and all(cs_index.loc[cs_name][flag] for flag in required_flags) + ] + if strict_cs: + coordinate_systems = strict_cs + + return coordinate_systems + + +def _plan_panels( + coordinate_systems: list[str], + render_cmds: list[_RenderCmd], + ax: list[Axes] | Axes | None, + cs_was_auto: bool, +) -> list[tuple[str, str | None]]: + """Determine the panel layout as ``(coordinate_system, panel_key)`` tuples. + + Panels are normally one per coordinate system, but when a render_* call passed a list of + color keys we instead lay out one panel per key within a single coordinate system + (scanpy-style ``color=[...]``). Render entries tagged with a ``panel_key`` belong to that + key's panel; untagged entries are shared across all panels. Also validates the panel count + against any user-supplied axes. + """ + panel_keys: list[str] = [] + for _cmd, _params in render_cmds: + pkey = getattr(_params, "panel_key", None) + if pkey is not None and pkey not in panel_keys: + panel_keys.append(pkey) + if panel_keys: + if len(coordinate_systems) != 1: + raise ValueError( + "A list of color keys (multi-panel plotting) requires exactly one coordinate system, " + f"but {len(coordinate_systems)} were selected: {coordinate_systems}. " + "Pass `coordinate_systems=` to choose a single one." + ) + panels: list[tuple[str, str | None]] = [(coordinate_systems[0], key) for key in panel_keys] + else: + panels = [(cs, None) for cs in coordinate_systems] + + if ax is not None: + n_ax = 1 if isinstance(ax, Axes) else len(ax) + if len(panels) != n_ax: + msg = f"Mismatch between number of matplotlib axes objects ({n_ax}) and number of panels ({len(panels)})." + if cs_was_auto: + msg += ( + " This can happen when elements have transformations to multiple " + "coordinate systems (e.g. after filter_by_coordinate_system). " + "Pass `coordinate_systems=` explicitly to select which ones to plot." + ) + raise ValueError(msg) + + return panels + + +def _build_legend_params( + legend_params: dict[str, Any] | None, + legend_fontsize: int | float | _FontSize | None, + legend_fontweight: int | _FontWeight, + legend_loc: str | None, + legend_fontoutline: int | None, + na_in_legend: bool, + colorbar: bool, + legend_title: str | None, + outline_legend_title: str | None, +) -> LegendParams: + """Build the :class:`LegendParams` bundle, applying the ``legend_params`` dict overrides. + + Keys in the ``legend_params`` dict take precedence over the matching flat ``legend_*`` + keyword arguments. + """ + if legend_params: + legend_fontsize = legend_params.get("fontsize", legend_fontsize) + legend_fontweight = legend_params.get("fontweight", legend_fontweight) + # `loc` is matplotlib.Legend's native key; `location` aligns with colorbar/scalebar. + legend_loc = legend_params.get("location", legend_params.get("loc", legend_loc)) + legend_fontoutline = legend_params.get("fontoutline", legend_fontoutline) + na_in_legend = legend_params.get("na_in_legend", na_in_legend) + + if legend_loc == "on data": + raise ValueError("legend_loc='on data' is not supported in spatialdata-plot.") + + return LegendParams( + legend_fontsize=legend_fontsize, + legend_fontweight=legend_fontweight, + legend_loc=legend_loc, + legend_fontoutline=legend_fontoutline, + na_in_legend=na_in_legend, + colorbar=colorbar, + legend_title=legend_title, + outline_legend_title=outline_legend_title, + ) + + +def _draw_colorbar( + spec: ColorbarSpec, + fig: Figure, + renderer: RendererBase, + trackers_axes: dict[str, float], + colorbar_params: dict[str, object] | None, +) -> None: + """Draw a single colorbar inset against ``spec.ax`` and update the side-offset trackers. + + ``trackers_axes`` accumulates, per side, how far out colorbars already extend so successive + colorbars on the same side stack instead of overlapping. + """ + norm = spec.mappable.norm + if isinstance(norm, LogNorm): + vmin, vmax = norm.vmin, norm.vmax + if vmin is None or vmax is None or vmin <= 0 or vmin >= vmax: + warnings.warn( + "Data contains zeros or non-positive values; colorbar suppressed for `LogNorm`. " + "Pass `colorbar=False` to silence this warning, or clip the data to positive values.", + UserWarning, + stacklevel=2, + ) + return + + base_layout = { + "location": CBAR_DEFAULT_LOCATION, + "fraction": CBAR_DEFAULT_FRACTION, + "pad": CBAR_DEFAULT_PAD, + } + layer_layout, layer_kwargs, layer_label_override = _split_colorbar_params(spec.params) + global_layout, global_kwargs, global_label_override = _split_colorbar_params(colorbar_params) + layout = {**base_layout, **layer_layout, **global_layout} + cbar_kwargs = {**layer_kwargs, **global_kwargs} + + location = cast(str, layout.get("location", base_layout["location"])) + if location not in {"left", "right", "top", "bottom"}: + location = CBAR_DEFAULT_LOCATION + default_orientation = "vertical" if location in {"right", "left"} else "horizontal" + cbar_kwargs.setdefault("orientation", default_orientation) + + fraction = float(cast(float | int, layout.get("fraction", base_layout["fraction"]))) + pad = float(cast(float | int, layout.get("pad", base_layout["pad"]))) + + vertical = location in {"left", "right"} + pad_axes = pad + trackers_axes[location] + # "left"/"bottom" grow outward in the negative direction; "right"/"top" past 1. + start = -pad_axes - fraction if location in {"left", "bottom"} else 1 + pad_axes + bbox = (float(start), 0.0, float(fraction), 1.0) if vertical else (0.0, float(start), 1.0, float(fraction)) + cax = inset_axes( + spec.ax, + width="100%", + height="100%", + loc="center", + bbox_to_anchor=bbox, + bbox_transform=spec.ax.transAxes, + borderpad=0.0, + ) + + cb = fig.colorbar(spec.mappable, cax=cax, **cbar_kwargs) + opposite = {"left": "right", "right": "left", "top": "bottom", "bottom": "top"}[location] + cbar_axis = cb.ax.yaxis if vertical else cb.ax.xaxis + cbar_axis.set_ticks_position(location) + cbar_axis.set_label_position(location) + cb.ax.tick_params(**{f"label{location}": True, f"label{opposite}": False}) + + final_label = global_label_override or layer_label_override or spec.label + if final_label: + cb.set_label(final_label) + if spec.alpha is not None: + with contextlib.suppress(Exception): + cb.solids.set_alpha(spec.alpha) + bbox_axes = cb.ax.get_tightbbox(renderer).transformed(spec.ax.transAxes.inverted()) + trackers_axes[location] = pad_axes + (bbox_axes.width if vertical else bbox_axes.height) + + +def _layout_pending_colorbars( + pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]], + fig_params: FigParams, + colorbar_params: dict[str, object] | None, +) -> None: + """Second-pass colorbar layout, run once the canvas geometry is known. + + For each axis, deduplicates colorbar requests by mappable, seeds the per-side offset + trackers from the axis' tight bounding box, and draws each colorbar so they stack outward + without overlapping the axis content. + """ + if not (pending_colorbars and fig_params.fig is not None): + return + + fig = fig_params.fig + fig.canvas.draw() + renderer = fig.canvas.get_renderer() + for axis, requests in pending_colorbars: + unique_specs: list[ColorbarSpec] = [] + seen_mappables: set[int] = set() + for spec in requests: + mappable_id = id(spec.mappable) + if mappable_id in seen_mappables: + continue + seen_mappables.add(mappable_id) + unique_specs.append(spec) + tight_bbox = axis.get_tightbbox(renderer).transformed(axis.transAxes.inverted()) + trackers_axes = { + "left": max(0.0, -tight_bbox.x0), + "right": max(0.0, tight_bbox.x1 - 1), + "bottom": max(0.0, -tight_bbox.y0), + "top": max(0.0, tight_bbox.y1 - 1), + } + for spec in unique_specs: + _draw_colorbar(spec, fig, renderer, trackers_axes, colorbar_params) + + +def _finalize_panel( + ax: Axes, + panel_idx: int, + title: list[str] | None, + panel_key: str | None, + cs: str, + frameon: bool | None, +) -> None: + """Set a panel's title, equal aspect ratio and frame visibility. + + With no explicit ``title`` the panel is labelled with its color key (multi-panel color mode) + or its coordinate-system name; a single-element list applies to every panel, otherwise the + title at ``panel_idx`` is used. + """ + if title is None: + t = panel_key if panel_key is not None else cs + elif len(title) == 1: + t = title[0] + else: + try: + t = title[panel_idx] + except IndexError as e: + raise IndexError("The number of titles must match the number of panels.") from e + ax.set_title(t) + ax.set_aspect("equal") + if frameon is False: + ax.axis("off") + + +def _should_rasterize( + render_params: ImageRenderParams | LabelsRenderParams, + dpi: int | None, + figsize: tuple[float, float] | None, +) -> bool: + """Rasterize when no scale is set, or a non-``"full"`` scale is paired with a fixed canvas size.""" + scale = render_params.scale + return scale is None or (isinstance(scale, str) and scale != "full" and (dpi is not None or figsize is not None)) + + +def _maybe_set_label_colors(sdata: sd.SpatialData, render_params: LabelsRenderParams) -> None: + """Materialize a categorical palette on the table annotating a labels element, if applicable.""" + table = render_params.table_name + if table is None or render_params.col_for_color is None: + return + colors = sc.get.obs_df(sdata[table], [render_params.col_for_color]) + if isinstance(colors[render_params.col_for_color].dtype, pd.CategoricalDtype): + _maybe_set_colors( + source=sdata[table], + target=sdata[table], + key=render_params.col_for_color, + palette=render_params.palette, + ) + + +def _render_panel( + sdata: sd.SpatialData, + render_cmds: list[_RenderCmd], + cs: str, + panel_key: str | None, + panel_idx: int, + ax: Axes, + fig_params: FigParams, + legend_params_obj: LegendParams, + axis_colorbar_requests: list[ColorbarSpec] | None, + axis_channel_legend_entries: list[ChannelLegendEntry], + cs_row: pd.Series, + title: list[str] | None, + dpi: int | None, + figsize: tuple[float, float] | None, +) -> tuple[list[str], dict[str, bool]]: + """Render every applicable render command into a single panel's axes. + + Dispatches each queued render command to its ``_render_*`` function, skipping entries that + belong to a different color panel (``panel_key``). Colorbar requests and channel-legend + entries accumulate on the passed-in lists. Returns the wanted element names and a per-type + ``wants`` flag dict (keyed ``"images"``/``"labels"``/``"points"``/``"shapes"``) needed + downstream for extent computation. + """ + renderers = { + "render_images": _render_images, + "render_shapes": _render_shapes, + "render_points": _render_points, + "render_labels": _render_labels, + } + wants = dict.fromkeys(("images", "labels", "points", "shapes"), False) + wanted_elements: list[str] = [] + + for cmd, params in render_cmds: + # Skip render entries that belong to a different color panel. Entries with no + # `panel_key` (None) are shared and drawn into every panel (e.g. a background image). + cmd_panel_key = getattr(params, "panel_key", None) + if panel_key is not None and cmd_panel_key is not None and cmd_panel_key != panel_key: + continue + # We create a copy here as the wanted elements can change from one cs to another. + params_copy = deepcopy(params) + + if cmd == "render_graph": + graph_element = params_copy.element + if graph_element in sdata and cs in get_transformation(sdata[graph_element], get_all=True): + _render_graph( + sdata=sdata, + render_params=params_copy, + coordinate_system=cs, + ax=ax, + legend_params=legend_params_obj, + colorbar_requests=axis_colorbar_requests, + ) + elif cmd in renderers and cs_row[_RENDER_CMD_TO_CS_FLAG[cmd]]: + element_type = cmd.removeprefix("render_") + wanted_elements, wanted_on_cs, wants[element_type] = _get_wanted_render_elements( + sdata, wanted_elements, params_copy, cs, element_type + ) + if wanted_on_cs: + kwargs: dict[str, Any] = { + "sdata": sdata, + "render_params": params_copy, + "coordinate_system": cs, + "ax": ax, + "fig_params": fig_params, + "legend_params": legend_params_obj, + "colorbar_requests": axis_colorbar_requests, + } + if cmd == "render_images": + kwargs["channel_legend_entries"] = axis_channel_legend_entries + if cmd in {"render_images", "render_labels"}: + kwargs["rasterize"] = _should_rasterize(params_copy, dpi, figsize) + if cmd == "render_labels": + _maybe_set_label_colors(sdata, params_copy) + renderers[cmd](**kwargs) + + _finalize_panel(ax, panel_idx, title, panel_key, cs, fig_params.frameon) + + return wanted_elements, wants