Source code for craft_cli.printer

# Copyright 2023 Canonical Ltd.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""The output (for different destinations) handler and helper functions."""

from __future__ import annotations

import itertools
import math
import queue
import shutil
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Callable, TextIO

if TYPE_CHECKING:
    import pathlib

# the char used to draw the progress bar ('FULL BLOCK', U+2588)
_PROGRESS_BAR_SYMBOL = "█"

# seconds a message must stay without a newer one arriving before the spinner
# starts animating it
_SPINNER_THRESHOLD = 2

# seconds between each spinner frame (each call to the printer's `spin`)
_SPINNER_DELAY = 0.1

# set to True when running *application* tests so some behaviours change (see
# craft_cli/pytest_plugin.py); e.g. the Printer never starts the spinner thread
TESTMODE = False


@dataclass
class _MessageInfo:
    """Comprehensive information for a message that may go to screen and log.

    Instances compare equal when every field except `created_at` is equal; this is what
    lets the printer and the spinner detect a repeated message.
    """

    # where to show the message; None means nothing is written to screen
    stream: TextIO | None
    # the message text
    text: str
    # an ephemeral message is overwritten by the next one written to the terminal
    ephemeral: bool = False
    # current and total values for progress bar messages; None for regular line messages
    bar_progress: int | float | None = None
    bar_total: int | float | None = None
    # prefix the shown text with the message creation timestamp
    use_timestamp: bool = False
    # finish the terminal line after writing the message (such messages are never spun)
    end_line: bool = False
    # creation time; excluded from comparisons so repeated messages compare equal
    created_at: datetime = field(default_factory=datetime.now, compare=False)
    # text shown before the message when written to a terminal
    terminal_prefix: str = ""


@lru_cache
def _stream_is_terminal(stream: TextIO | None) -> bool:
    is_a_terminal = getattr(stream, "isatty", lambda: False)()
    return is_a_terminal and _get_terminal_width() > 0


def _get_terminal_width() -> int:
    """Return the number of columns of the terminal."""
    return shutil.get_terminal_size().columns


def _format_term_line(prefix: str, text: str, spintext: str, *, ephemeral: bool) -> str:
    """Format a line to print to the terminal."""
    # fill with spaces until the very end, on one hand to clear a possible previous message,
    # but also to always have the cursor at the very end
    width = _get_terminal_width()
    usable = width - len(spintext) - 1  # the 1 is the cursor itself
    if len(text) > usable:
        if ephemeral:
            text = text[: usable - 1] + "…"
        elif spintext:
            # we need to rewrite the message with the spintext, use only the last line for
            # multiline messages, and ensure (again) that the last real line fits
            remaining_for_last_line = len(text) % width
            text = text[-remaining_for_last_line:]
            if len(text) > usable:
                text = text[: usable - 1] + "…"
    cleaner = " " * (usable - len(text) % width)

    return prefix + text + spintext + cleaner


class _Spinner(threading.Thread):
    """A supervisor thread that will repeat long-standing messages with a spinner beside it.

    This will be a long-lived single thread that will supervise each message received
    through the `supervise` method, and when it stays too long, the printer's `spin`
    will be called with that message and a text to "draw" a spinner, including the elapsed
    time.

    The timing related part of the code uses two constants: _SPINNER_THRESHOLD is how
    many seconds before activating the spinner for the message, and _SPINNER_DELAY is
    the time between `spin` calls.

    When a new message arrives (or None, to indicate that there is nothing to supervise) and
    the previous message was "being spun", a last `spin` call will be done to clean
    the spinner. Messages with `end_line` set are never spun.
    """

    def __init__(self, printer: Printer) -> None:
        """Prepare the supervisor thread for the given printer (it is not started here)."""
        super().__init__()
        # special flag used to stop the spinner thread
        self.stop_flag = object()

        # daemon mode, so if the app crashes this thread does not hold everything
        self.daemon = True

        # communication from the printer: messages to supervise, None, or the stop flag
        self.queue: queue.Queue[Any] = queue.Queue()

        # hold the printer, to make it spin
        self.printer = printer

        # held by `run` during the whole spinning state, so `supervise` can wait on it
        # until the spinner is cleaned
        self.lock = threading.Lock()

        # Keep the message under supervision available for examination.
        self._under_supervision: _MessageInfo | None = None

    def run(self) -> None:
        """Supervise the received messages until the stop flag arrives.

        Each message is waited for up to _SPINNER_THRESHOLD seconds; if nothing new arrives
        in that time, the previous message is spun every _SPINNER_DELAY seconds until
        something new comes in.
        """
        prv_msg = None
        # moment when `prv_msg` was received, to show the elapsed time in the spinner
        t_init = time.time()
        while prv_msg is not self.stop_flag:
            try:
                new_msg = self.queue.get(timeout=_SPINNER_THRESHOLD)
            except queue.Empty:
                # waited too much, start to show a spinner (if have a previous message) until
                # we have further info; messages that ended the line are never spun
                if prv_msg is None or prv_msg.end_line:
                    continue
                spinchars = itertools.cycle("-\\|/")
                with self.lock:
                    while True:
                        t_delta = time.time() - t_init
                        spintext = f" {next(spinchars)} ({t_delta:.1f}s)"
                        self.printer.spin(prv_msg, spintext)
                        try:
                            new_msg = self.queue.get(timeout=_SPINNER_DELAY)
                        except queue.Empty:
                            # still nothing! keep going
                            continue
                        # got a new message (maybe the stop flag): clean the spinner and
                        # exit from the spinning state
                        self.printer.spin(prv_msg, " ")
                        break

            prv_msg = new_msg
            t_init = time.time()

    def supervise(self, message: _MessageInfo | None) -> None:
        """Supervise a message to spin it if it remains too long.

        Passing None means there is nothing to supervise (a running spinner gets cleaned).
        """
        # Don't bother the spinner if we're repeating the same message
        if message == self._under_supervision:
            return

        self._under_supervision = message
        self.queue.put(message)
        # (maybe) wait for the spinner to exit spinning state (which does some cleaning):
        # `run` holds the lock while spinning, so acquiring it blocks until that is done.
        # NOTE(review): if `run` has just timed out but not yet taken the lock, this does
        # not wait; confirm that window is acceptable.
        self.lock.acquire()
        self.lock.release()

    def stop(self) -> None:
        """Stop self.

        Send the stop flag through the queue and wait for the thread to finish. Note that
        `Thread.join` raises RuntimeError if the thread was never started.
        """
        self.queue.put(self.stop_flag)
        self.join()


class Printer:
    """Handle writing the different messages to the different outputs (out, err and log).

    If TESTMODE is True, this class changes its behaviour: the spinner is never started,
    so there is no thread polluting messages when running tests if they take too long to run.

    After `stop` is called, further `show` and `progress_bar` calls are ignored.
    """

    def __init__(self, log_filepath: pathlib.Path) -> None:
        self.stopped = False

        # holder of the previous message
        self.prv_msg: _MessageInfo | None = None

        # open the log file (will be closed explicitly later)
        self.log = log_filepath.open("at", encoding="utf8")

        # keep account of output terminal streams with unfinished lines
        self.unfinished_stream: TextIO | None = None

        # text prepended to every message shown to the terminal
        self.terminal_prefix = ""

        # strings to mask out in all outputs
        self.secrets: list[str] = []

        # run the spinner supervisor
        self.spinner = _Spinner(self)
        if not TESTMODE:
            self.spinner.start()

    def set_terminal_prefix(self, prefix: str) -> None:
        """Set the string to be prepended to every message shown to the terminal."""
        self.terminal_prefix = prefix

    def _get_prefixed_message_text(self, message: _MessageInfo) -> str:
        """Get the message's text with the proper terminal prefix, if any."""
        text = message.text
        prefix = message.terminal_prefix

        # Don't repeat text: can happen due to the spinner.
        if prefix and text != prefix:
            separator = ":: "

            # Don't duplicate the separator, which can come from multiple different
            # sources.
            if text.startswith(separator):
                separator = ""

            text = f"{prefix} {separator}{text}"

        return text

    def _write_line_terminal(self, message: _MessageInfo, *, spintext: str = "") -> None:
        """Write a simple line message to the screen."""
        # prepare the text with (maybe) the timestamp
        text = self._get_prefixed_message_text(message)
        if message.use_timestamp:
            timestamp_str = message.created_at.isoformat(sep=" ", timespec="milliseconds")
            text = f"{timestamp_str} {text}"

        if spintext:
            # forced to overwrite the previous message to present the spinner
            maybe_cr = "\r"
        elif self.prv_msg is None or self.prv_msg.end_line:
            # first message, or previous message completed the line: start clean
            maybe_cr = ""
        elif self.prv_msg.ephemeral:
            # the last one was ephemeral, overwrite it
            maybe_cr = "\r"
            if self.prv_msg.stream != message.stream:
                # If the last message's stream is different from this new one,
                # send the carriage return to the original stream only.
                print(maybe_cr, flush=True, file=self.prv_msg.stream, end="")
                maybe_cr = ""
        else:
            # complete the previous line, leaving that message ok
            maybe_cr = ""
            print(flush=True, file=self.prv_msg.stream)

        # We don't need to rewrite the same ephemeral message repeatedly.
        should_overwrite = spintext or message.end_line or not message.ephemeral
        if should_overwrite or message != self.prv_msg:
            line = _format_term_line(maybe_cr, text, spintext, ephemeral=message.ephemeral)
            print(line, end="", flush=True, file=message.stream)

        if message.end_line:
            # finish the just shown line, as we need a clean terminal for some external thing
            print(flush=True, file=message.stream)
            self.unfinished_stream = None
        else:
            self.unfinished_stream = message.stream

    def _write_line_captured(self, message: _MessageInfo) -> None:
        """Write a simple line message to a captured output."""
        # prepare the text with (maybe) the timestamp
        if message.use_timestamp:
            timestamp_str = message.created_at.isoformat(sep=" ", timespec="milliseconds")
            text = timestamp_str + " " + message.text
        else:
            text = message.text

        print(text, file=message.stream)

    def _write_bar_terminal(self, message: _MessageInfo) -> None:
        """Write a progress bar to the screen."""
        # prepare the text with (maybe) the timestamp
        if message.use_timestamp:
            timestamp_str = message.created_at.isoformat(sep=" ", timespec="milliseconds")
            text = timestamp_str + " " + message.text
        else:
            text = message.text

        if self.prv_msg is None or self.prv_msg.end_line:
            # first message, or previous message completed the line: start clean
            maybe_cr = ""
        elif self.prv_msg.ephemeral:
            # the last one was ephemeral, overwrite it
            maybe_cr = "\r"
        else:
            # complete the previous line, leaving that message ok
            maybe_cr = ""
            print(flush=True, file=self.prv_msg.stream)

        if message.bar_progress is None or message.bar_total is None:  # pragma: no cover
            # Should not happen as the caller checks the message
            raise ValueError("Tried to write a bar message with invalid attributes")

        numerical_progress = f"{message.bar_progress}/{message.bar_total}"
        if message.bar_total:
            ratio = message.bar_progress / message.bar_total
        else:
            # nothing to do (e.g. an empty download) counts as fully done; this also
            # avoids a ZeroDivisionError
            ratio = 1.0
        # clamp to [0, 1] so out-of-range progress never draws a bar wider than its space
        bar_percentage = max(0.0, min(ratio, 1.0))

        # terminal size minus the text and numerical progress, and 5 (the cursor at the end,
        # two spaces before and after the bar, and two surrounding brackets)
        terminal_width = _get_terminal_width()
        bar_width = terminal_width - len(text) - len(numerical_progress) - 5

        # only show the bar with progress if there is enough space, otherwise just the
        # message (truncated, if needed)
        if bar_width > 0:
            completed_width = math.floor(bar_width * bar_percentage)
            completed_bar = _PROGRESS_BAR_SYMBOL * completed_width
            empty_bar = " " * (bar_width - completed_width)
            line = f"{maybe_cr}{text} [{completed_bar}{empty_bar}] {numerical_progress}"
        else:
            text = text[: terminal_width - 1]  # space for cursor
            line = f"{maybe_cr}{text}"

        print(line, end="", flush=True, file=message.stream)
        self.unfinished_stream = message.stream

    def _write_bar_captured(self, message: _MessageInfo) -> None:
        """Do not write any progress bar to the captured output."""

    def _show(self, msg: _MessageInfo) -> None:
        """Show the composed message."""
        # show the message in one way or the other only if there is a stream
        if msg.stream is None:
            return

        # the writing functions depend on the final output: if the stream is captured or it's
        # a real terminal
        write_line: Callable[[_MessageInfo], None]
        if _stream_is_terminal(msg.stream):
            write_line = self._write_line_terminal
            write_bar = self._write_bar_terminal
        else:
            write_line = self._write_line_captured
            write_bar = self._write_bar_captured

        if msg.bar_progress is None:
            # regular message, send it to the spinner and write it
            self.spinner.supervise(msg)
            write_line(msg)
        else:
            # progress bar, send None to the spinner (as it's not a "spinnable" message)
            # and write it
            self.spinner.supervise(None)
            write_bar(msg)

        self.prv_msg = msg

    def _log(self, message: _MessageInfo) -> None:
        """Write the line message to the log file."""
        # log lines always carry the timestamp
        timestamp_str = message.created_at.isoformat(sep=" ", timespec="milliseconds")
        self.log.write(f"{timestamp_str} {message.text}\n")

        # Flush the file: protect a bit in case of crashes, and multiprocess-based
        # parallelism.
        self.log.flush()

    def spin(self, message: _MessageInfo, spintext: str) -> None:
        """Write a line message including a spin text, only to a terminal."""
        if _stream_is_terminal(message.stream):
            self._write_line_terminal(message, spintext=spintext)

    def show(  # noqa: PLR0913 (too many parameters)
        self,
        stream: TextIO | None,
        text: str,
        *,
        ephemeral: bool = False,
        use_timestamp: bool = False,
        end_line: bool = False,
        avoid_logging: bool = False,
    ) -> None:
        """Show a text to the given stream if not stopped."""
        if self.stopped:
            return

        text = self._apply_secrets(text)

        msg = _MessageInfo(
            stream=stream,
            text=text.rstrip(),
            ephemeral=ephemeral,
            use_timestamp=use_timestamp,
            end_line=end_line,
            terminal_prefix=self._apply_secrets(self.terminal_prefix),
        )
        self._show(msg)
        if not avoid_logging:
            self._log(msg)

    def progress_bar(  # noqa: PLR0913
        self,
        stream: TextIO | None,
        text: str,
        *,
        progress: float,
        total: float,
        use_timestamp: bool,
    ) -> None:
        """Show a progress bar to the given stream if not stopped."""
        # same as `show`: after `stop` the terminal was already cleaned up
        if self.stopped:
            return

        text = self._apply_secrets(text)

        msg = _MessageInfo(
            stream=stream,
            text=text.rstrip(),
            bar_progress=progress,
            bar_total=total,
            ephemeral=True,  # so it gets eventually overwritten by other message
            use_timestamp=use_timestamp,
        )
        self._show(msg)

    def stop(self) -> None:
        """Stop the printing infrastructure.

        In detail:
        - stop the spinner
        - add a new line to the screen (if needed)
        - close the log file
        """
        if not TESTMODE:
            self.spinner.stop()
        if self.unfinished_stream is not None:
            # With unfinished_stream set, the prv_msg object is valid.
            if self.prv_msg is not None and self.prv_msg.ephemeral:
                # If the last printed message is of 'ephemeral' type, the stop
                # request must clean and reset the line.
                cleaner = " " * (_get_terminal_width() - 1)
                line = "\r" + cleaner + "\r"
                print(line, end="", flush=True, file=self.prv_msg.stream)
            else:
                # The last printed message is permanent. Leave the cursor on
                # the next clean line.
                print(flush=True, file=self.unfinished_stream)

        self.log.close()
        self.stopped = True

    def set_secrets(self, secrets: list[str]) -> None:
        """Set the list of strings that should be masked out in all outputs."""
        # Keep a copy, to protect against clients modifying the list on accident.
        self.secrets = secrets.copy()

    def _apply_secrets(self, text: str) -> str:
        """Mask out every configured secret found in the text."""
        for secret in self.secrets:
            # an empty secret would make str.replace insert the mask between every char
            if secret:
                text = text.replace(secret, "*****")
        return text