"""Classes for redirecting system input and output.

Note: The first three chapters use input() and print() instead of function
arguments and return values. As a result, we need to "redirect" the source
of input() and print() to strings. The RedirectInput and RedirectOutput
classes are used for this purpose. The run_module() function uses a "with"
statement to run a module. When the with block is entered, the __enter__
functions are called. When the with block is exited, the __exit__ functions
are called. These functions replace (and restore) system input and output,
even if an error is raised while running the module that is being tested.
"""

import importlib
import io
import sys


class RedirectInput:
    """Redirect system input to the provided string.

    Use as a context manager. While the ``with`` block is active,
    ``sys.stdin`` is an in-memory buffer holding ``input_str``. On exit, the
    stream that was installed when the block was entered is put back, so
    redirections can nest and a stdin that was already replaced (for example,
    by an outer redirection) is left intact.

    Args:
        input_str (str): The text that reads from system input will return.
    """

    def __init__(self, input_str):
        # Create buffer for input.
        self.buffer = io.StringIO(input_str)
        # Streams displaced by __enter__, most recent last. A stack (rather
        # than a single attribute) keeps nested use of one instance correct.
        self._saved = []

    def __enter__(self):
        # Remember the current stdin, then replace it with the buffer.
        self._saved.append(sys.stdin)
        sys.stdin = self.buffer
        return self.buffer

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the stdin that was active before __enter__. Returning None
        # lets any exception raised inside the with block propagate.
        sys.stdin = self._saved.pop()


class RedirectOutput:
    """Redirect system output to a new string buffer.

    Use as a context manager. While the ``with`` block is active,
    ``sys.stdout`` is an in-memory buffer; everything written to it can be
    read afterwards with ``self.buffer.getvalue()``. On exit, the stream that
    was installed when the block was entered is put back, so redirections can
    nest and a stdout that was already replaced (for example, by an outer
    redirection) is left intact.
    """

    def __init__(self):
        # Create buffer for output.
        self.buffer = io.StringIO()
        # Streams displaced by __enter__, most recent last. A stack (rather
        # than a single attribute) keeps nested use of one instance correct.
        self._saved = []

    def __enter__(self):
        # Remember the current stdout, then replace it with the buffer.
        self._saved.append(sys.stdout)
        sys.stdout = self.buffer
        return self.buffer

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the stdout that was active before __enter__. Returning None
        # lets any exception raised inside the with block propagate.
        sys.stdout = self._saved.pop()


def run_module(name, input_str=""):
    """Execute a module with system input and output redirected to strings.

    The module's top-level code runs as a side effect of importing it. If the
    module is already present in ``sys.modules`` it is reloaded so that its
    code runs again; otherwise it is imported for the first time.

    Args:
        name (str): The name of the module to run.
        input_str (str): The text the module will read from system input.

    Returns:
        str: Everything the module wrote to system output.
    """
    stdin_redirect = RedirectInput(input_str)
    stdout_redirect = RedirectOutput()
    # The context managers undo the redirection on exit, even when the
    # module raises; the exception itself still propagates to the caller.
    with stdin_redirect, stdout_redirect:
        loaded = sys.modules.get(name)
        if not loaded:
            # First run: importing executes the module body.
            importlib.import_module(name)
        else:
            # Already imported: reload to execute the module body again.
            importlib.reload(loaded)
    # Hand back the text collected in the output buffer.
    captured = stdout_redirect.buffer.getvalue()
    return captured
