# Source code for kfp.dsl._component

# Copyright 2018 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ._metadata import _extract_component_metadata
from ._pipeline_param import PipelineParam
from .types import check_types, InconsistentTypeException
from ._ops_group import Graph
import kfp

def python_component(name, description=None, base_image=None, target_component_file: str = None):
    """Decorator for Python component functions.

    The returned decorator attaches component metadata directly to the
    decorated function object and hands back that same object.

    Args:
      name: Human-readable name of the component.
      description: Optional. Description of the component.
      base_image: Optional. Docker container image to use as the base of the
        component. Needs to have Python 3.5+ installed.
      target_component_file: Optional. Local file to store the component
        definition. The file can then be used for sharing.

    Returns:
      A decorator. Applying it sets the metadata attributes on the function
      and returns the very same function.

    Usage:
    ```python
    @dsl.python_component(
      name='my awesome component',
      description="Come, Let's play",
      base_image='tensorflow/tensorflow:1.11.0-py3',
    )
    def my_component(a: str, b: int) -> str:
      ...
    ```
    """
    # Optional metadata is recorded only when truthy. Falsy values such as
    # None or '' leave the corresponding attribute unset on the function.
    optional_attrs = (
        ('_component_description', description),
        ('_component_base_image', base_image),
        ('_component_target_component_file', target_component_file),
    )

    def _python_component(func):
        func._component_human_name = name
        for attr_name, value in optional_attrs:
            if value:
                setattr(func, attr_name, value)
        return func

    return _python_component
def component(func):
    """Decorator for component functions that return a ContainerOp.

    This is useful to enable type checking in the DSL compiler. When
    ``kfp.TYPE_CHECK`` is truthy, each ``PipelineParam`` argument is checked
    against the type declared for the matching input in the component
    metadata before ``func`` is called:

    * Positional arguments are matched to declared inputs by position.
    * Keyword arguments are matched to declared inputs by name.

    After ``func`` returns, the component metadata is attached to the
    returned ContainerOp via ``_set_metadata``.

    Usage:
    ```python
    @dsl.component
    def foobar(model: TFModel(), step: MLStep()):
      return dsl.ContainerOp()
    ```

    Raises (when the decorated function is called):
      InconsistentTypeException: type checking is enabled and a PipelineParam
        argument's type is incompatible with the declared input type.
    """
    from functools import wraps

    def _is_compatible(arg, input_spec):
        """Return whether PipelineParam ``arg`` matches ``input_spec``'s declared type."""
        return check_types(arg.param_type.to_dict_or_str(),
                           input_spec.param_type.to_dict_or_str())

    def _type_mismatch(component_name, input_spec, arg):
        """Build the exception describing a type mismatch for one argument."""
        return InconsistentTypeException(
            'Component "' + component_name + '" is expecting ' + input_spec.name +
            ' to be type(' + input_spec.param_type.serialize() +
            '), but the passed argument is type(' + arg.param_type.serialize() + ')')

    @wraps(func)
    def _component(*args, **kargs):
        # Metadata is extracted on every call. The extracted object is what
        # gets attached to this call's ContainerOp below.
        component_meta = _extract_component_metadata(func)
        if kfp.TYPE_CHECK:
            # zip() stops at the shorter sequence. Surplus positional
            # arguments are therefore left unchecked instead of raising
            # IndexError as the old manual indexing did.
            for arg, input_spec in zip(args, component_meta.inputs):
                if isinstance(arg, PipelineParam) and not _is_compatible(arg, input_spec):
                    raise _type_mismatch(component_meta.name, input_spec, arg)
            for key, arg in kargs.items():
                if not isinstance(arg, PipelineParam):
                    continue
                for input_spec in component_meta.inputs:
                    if input_spec.name == key and not _is_compatible(arg, input_spec):
                        raise _type_mismatch(component_meta.name, input_spec, arg)
        container_op = func(*args, **kargs)
        container_op._set_metadata(component_meta)
        return container_op

    return _component
# TODO: Combine the component and graph_component decorators into one.
def graph_component(func):
    """Decorator for graph component functions.

    Calling the decorated function returns an ops_group (a ``Graph`` named
    after the function) rather than the function's own return value.

    Usage:
    ```python
    import kfp.dsl as dsl
    @dsl.graph_component
    def flip_component(flip_result):
      print_flip = PrintOp(flip_result)
      flipA = FlipCoinOp().after(print_flip)
      with dsl.Condition(flipA.output == 'heads'):
        flip_component(flipA.output)
      return {'flip_result': flipA.output}
    ```

    Raises (when the decorated function is called):
      ValueError: any positional or keyword argument is not a PipelineParam.
    """
    from functools import wraps

    @wraps(func)
    def _graph_component(*args, **kargs):
        group = Graph(func.__name__)
        # Positional arguments first, then keyword-argument values, all
        # become the group's inputs.
        graph_inputs = list(args)
        graph_inputs.extend(kargs.values())
        group.inputs = graph_inputs
        if not all(isinstance(arg, PipelineParam) for arg in group.inputs):
            raise ValueError('arguments to ' + func.__name__ + ' should be PipelineParams.')

        with group:
            # NOTE(review): recursive_ref is read only after entering the
            # Graph context. Keep that order, since entering may update it
            # (assumption; confirm against Graph.__enter__).
            if not group.recursive_ref:
                func(*args, **kargs)
        return group

    return _graph_component