# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -*-
#
# Copyright 2015-2024 Canonical Ltd.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Base classes for source type handling."""
import abc
import logging
import os
import shutil
import subprocess
from collections.abc import Sequence
from pathlib import Path
from typing import Any, ClassVar
import pydantic
import requests
from overrides import overrides
from craft_parts.dirs import ProjectDirs
from craft_parts.utils import os_utils, url_utils
from . import errors
from .cache import FileCache
from .checksum import verify_checksum
logger = logging.getLogger(__name__)
[docs]
def get_model_config(
    json_schema_extra: dict[str, Any] | None = None,
) -> pydantic.ConfigDict:
    """Build the pydantic configuration shared by the source models.

    The resulting configuration exposes every field under an alias where
    underscores are replaced by hyphens, rejects unknown fields, and passes
    ``json_schema_extra`` through unchanged as pydantic's setting of the same
    name.

    :param json_schema_extra: Value for pydantic's ``json_schema_extra``
        setting, or None.
    :return: The model configuration.
    """

    def _hyphenate(field_name: str) -> str:
        return field_name.replace("_", "-")

    settings = pydantic.ConfigDict(
        extra="forbid",
        json_schema_extra=json_schema_extra,
        alias_generator=_hyphenate,
    )
    return settings
[docs]
class BaseSourceModel(pydantic.BaseModel, frozen=True):  # type: ignore[misc]
    """A base model for source types.

    Instances are immutable (``frozen=True``). The configuration comes from
    :func:`get_model_config`, so fields are addressed by hyphenated aliases
    (e.g. ``source-type``) and unknown fields are rejected.
    """

    model_config = get_model_config()
    # Name of the source type. SourceHandler reads the JSON-schema default of
    # this field, so subclasses presumably pin a default value — confirm in
    # the concrete source models.
    source_type: str
    # The source location (e.g. a URL or a local path for file sources).
    source: str
[docs]
class BaseFileSourceModel(BaseSourceModel, frozen=True):
    """A base model for file-based source types.

    Adds an optional checksum for the source file on top of the common
    ``source-type`` and ``source`` fields.
    """

    # Checksum of the source file, or None. When set, FileSourceHandler
    # verifies the retrieved file against it and uses it as the download
    # cache key.
    source_checksum: str | None = None
[docs]
class SourceHandler(abc.ABC):
    """The base class for source type handlers.

    Subclasses set :attr:`source_model` to the pydantic model describing their
    options and implement :meth:`pull`. Methods :meth:`check_if_outdated`,
    :meth:`get_outdated_files` and :meth:`update` can be overridden by
    subclasses to implement verification and update of source files; the base
    implementations raise :class:`errors.SourceUpdateUnsupported`.

    Fields of the validated source model are also readable directly as
    attributes of the handler (see :meth:`__getattr__`).
    """

    # The pydantic model used to validate this handler's options.
    source_model: ClassVar[type[BaseSourceModel]]

    def __init__(
        self,
        source: str,
        part_src_dir: Path,
        *,
        cache_dir: Path,
        project_dirs: ProjectDirs,
        ignore_patterns: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Validate the source options and initialize the handler.

        :param source: The source location.
        :param part_src_dir: The part's source directory.
        :param cache_dir: Directory handed to subclasses for caching.
        :param project_dirs: The project directories.
        :param ignore_patterns: Patterns stored (as a copy) for use by
            subclasses; defaults to an empty list.
        :param kwargs: Additional source options, validated against
            :attr:`source_model`. Options the model does not know about are
            silently dropped when unset (falsy) and rejected otherwise.

        :raise errors.InvalidSourceOption: If exactly one unsupported option
            is set.
        :raise errors.InvalidSourceOptions: If several unsupported options
            are set.
        :raise pydantic.ValidationError: If the options fail model validation.
        """
        if not ignore_patterns:
            ignore_patterns = []

        invalid_options = []
        # Source model fields are addressed by hyphenated aliases, so the
        # keyword names are translated before validation.
        model_params = {key.replace("_", "-"): value for key, value in kwargs.items()}
        model_params["source"] = source
        properties = self.source_model.model_json_schema()["properties"]

        for option, value in kwargs.items():
            option_alias = option.replace("_", "-")
            if option_alias not in properties:
                # Unsupported options are tolerated only when unset; drop them
                # so validation (which forbids extra fields) does not fail.
                # pop() instead of del: two keywords can map to one alias.
                if not value:
                    model_params.pop(option_alias, None)
                else:
                    invalid_options.append(option_alias)

        if len(invalid_options) > 1:
            raise errors.InvalidSourceOptions(
                source_type=properties["source-type"]["default"],
                options=invalid_options,
            )
        if len(invalid_options) == 1:
            raise errors.InvalidSourceOption(
                source_type=properties["source-type"]["default"],
                option=invalid_options[0],
            )

        self._data = self.source_model.model_validate(model_params)

        self.source = source
        self.part_src_dir = part_src_dir
        self._cache_dir = cache_dir
        self.source_details: dict[str, str | None] | None = None
        self._dirs = project_dirs
        self._checked = False
        self._ignore_patterns = ignore_patterns.copy()

        self.outdated_files: list[str] | None = None
        self.outdated_dirs: list[str] | None = None

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401 (this must be dynamic)
        """Fall back to the validated source model for unknown attributes.

        :raise AttributeError: If the model has not been created yet or does
            not have the attribute.
        """
        # Read ``_data`` from the instance dict rather than via ``self._data``:
        # if it isn't set yet (during ``__init__``, or on an instance created
        # by ``copy``/``pickle`` without running ``__init__``), ``self._data``
        # would re-enter this method and recurse until RecursionError.
        try:
            data = self.__dict__["_data"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(data, name)

    @abc.abstractmethod
    def pull(self) -> None:
        """Retrieve the source file."""

    def check_if_outdated(
        self, target: str, *, ignore_files: list[str] | None = None  # noqa: ARG002
    ) -> bool:
        """Check if pulled sources have changed since target was created.

        :param target: Path to target file.
        :param ignore_files: Files excluded from verification.

        :return: Whether the sources are outdated.

        :raise errors.SourceUpdateUnsupported: If the source handler can't check if
            files are outdated.
        """
        raise errors.SourceUpdateUnsupported(self.__class__.__name__)

    def get_outdated_files(self) -> tuple[list[str], list[str]]:
        """Obtain lists of outdated files and directories.

        :return: The lists of outdated files and directories.

        :raise errors.SourceUpdateUnsupported: If the source handler can't check if
            files are outdated.
        """
        raise errors.SourceUpdateUnsupported(self.__class__.__name__)

    def update(self) -> None:
        """Update pulled source.

        :raise errors.SourceUpdateUnsupported: If the source can't update its files.
        """
        raise errors.SourceUpdateUnsupported(self.__class__.__name__)

    @classmethod
    def _run(cls, command: list[str], **kwargs: Any) -> None:
        """Run a command, logging its output at debug level.

        :raise errors.PullError: If the command exits with a non-zero status.
        """
        try:
            os_utils.process_run(command, logger.debug, **kwargs)
        except subprocess.CalledProcessError as err:
            raise errors.PullError(command=command, exit_code=err.returncode) from err

    @classmethod
    def _run_output(cls, command: Sequence) -> str:
        """Run a command and return its stripped standard output as text.

        :raise errors.PullError: If the command exits with a non-zero status.
        """
        try:
            return subprocess.check_output(command, text=True).strip()
        except subprocess.CalledProcessError as err:
            raise errors.PullError(command=command, exit_code=err.returncode) from err
[docs]
class FileSourceHandler(SourceHandler):
    """Base class for file source types.

    Retrieves a single file — downloaded when the source is a URL, copied
    otherwise — into the part source directory, verifies it against
    ``source_checksum`` when one is set, and passes it to :meth:`provision`.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        source: str,
        part_src_dir: Path,
        *,
        cache_dir: Path,
        project_dirs: ProjectDirs,
        source_checksum: str | None = None,
        command: str | None = None,
        ignore_patterns: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            source,
            part_src_dir,
            cache_dir=cache_dir,
            source_checksum=source_checksum,
            command=command,
            project_dirs=project_dirs,
            ignore_patterns=ignore_patterns,
            **kwargs,
        )
        # Destination of the most recent download; set by download().
        self._file = Path()

    # pylint: enable=too-many-arguments

    @abc.abstractmethod
    def provision(
        self,
        dst: Path,
        keep: bool = False,  # noqa: FBT001, FBT002
        src: Path | None = None,
    ) -> None:
        """Process the source file to extract its payload.

        :param dst: The directory the payload is provisioned into.
        :param keep: Presumably whether to keep ``src`` after provisioning
            (pull() notes that provisioning may delete the file) — confirm
            against the implementations.
        :param src: The source file to process.
        """

    @overrides
    def pull(self) -> None:
        """Retrieve this source from its origin.

        URL sources are fetched with :meth:`download`; anything else is treated
        as a local path and copied into the part source directory. If
        ``source_checksum`` is set, the file is verified before it is
        provisioned into the part source directory.

        :raise errors.SourceNotFound: If a local source file does not exist.
        """
        # First check if it is a url and download and if not
        # it is probably locally referenced.
        if url_utils.is_url(self.source):
            source_file = self.download()
        else:
            source_file = Path(self.part_src_dir, os.path.basename(self.source))
            # We make this copy as the provisioning logic can delete
            # this file and we don't want that.
            try:
                shutil.copy2(self.source, source_file)
            except FileNotFoundError as err:
                raise errors.SourceNotFound(self.source) from err

        # Verify before provisioning. For fresh downloads this repeats the
        # check done in download(), but it also covers cache hits and local
        # files, which download() does not verify.
        if self.source_checksum:
            verify_checksum(self.source_checksum, source_file)

        self.provision(self.part_src_dir, src=source_file)

    def download(self, filepath: Path | None = None) -> Path:
        """Download the URL from a remote location.

        When ``source_checksum`` is set and the cache already holds a file
        under that key, the cached file is copied instead of downloading.
        Otherwise the file is downloaded and, if ``source_checksum`` is set,
        verified and stored in the cache.

        :param filepath: the destination file to download to. Defaults to the
            source's basename inside the part source directory.

        :return: The path of the downloaded (or cache-copied) file.

        :raise NotImplementedError: If the source uses the ``ftp`` scheme.
        :raise errors.SourceNotFound: If the server responds with 404.
        :raise errors.HttpRequestError: On any other HTTP error status.
        :raise errors.NetworkRequestError: If the request fails otherwise.
        """
        if filepath is None:
            self._file = Path(self.part_src_dir, os.path.basename(self.source))
        else:
            self._file = filepath

        # check if we already have the source file cached
        file_cache = FileCache(self._cache_dir)
        if self.source_checksum:
            cache_file = file_cache.get(key=self.source_checksum)
            if cache_file:
                # We make this copy as the provisioning logic can delete
                # this file and we don't want that.
                shutil.copy2(cache_file, self._file)
                return self._file

        # if not we download and store
        if url_utils.get_url_scheme(self.source) == "ftp":
            raise NotImplementedError("ftp download not implemented")

        try:
            response = requests.get(
                self.source, stream=True, allow_redirects=True, timeout=3600
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as err:
            # With stream=True the body is never read on this path, so the
            # connection must be released explicitly.
            err.response.close()
            if err.response.status_code == requests.codes.not_found:
                raise errors.SourceNotFound(source=self.source) from err
            raise errors.HttpRequestError(
                status_code=err.response.status_code,
                reason=err.response.reason,
                source=self.source,
            ) from err
        except requests.exceptions.RequestException as err:
            raise errors.NetworkRequestError(
                message=f"network request failed (request={err.request!r}, "
                f"response={err.response!r})",
                source=self.source,
            ) from err

        # Close the streamed response even if writing the file fails part-way.
        with response:
            url_utils.download_request(response, str(self._file))

        # if source_checksum is defined cache the file for future reuse
        if self.source_checksum:
            verify_checksum(self.source_checksum, self._file)
            file_cache.cache(filename=str(self._file), key=self.source_checksum)

        return self._file