 import traceback
 import warnings
 from contextlib import contextmanager
-from typing import Union, TypeVar, Sequence, Any, Literal
+from typing import Union, TypeVar, Sequence, Any, Literal, Generic

 from dataclasses import dataclass
 from typing import Tuple, Callable, List
 from ..backend.xops import ExtraOperator


-class Tensor:
+T = TypeVar('T')  # data type of tensors
+
+class Tensor(Generic[T]):
     """
     Abstract base class to represent structured data of one data type.
     This class replaces the native tensor classes `numpy.ndarray`, `torch.Tensor`, `tensorflow.Tensor` or `jax.numpy.ndarray` as the main data container in Φ-ML.
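Worth spelling out what the new `Generic[T]` parameterization buys: `T` is purely a static-analysis construct, erased at runtime, so element types can flow through annotations without any runtime cost. A minimal sketch, assuming the usual `phiml.math.wrap` re-export:

```python
from phiml.math import wrap

x = wrap([1.0, 2.0, 3.0])          # list becomes a 'vector' channel dim; statically a Tensor[float]
mask = x > 1.5                     # comparisons below are annotated to return Tensor[bool]
value: float = x.vector[0].item()  # item() returns T, here float
```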
@@ -470,13 +472,13 @@ def _getitem(self, selection: dict) -> 'Tensor':
     def __setitem__(self, key, value):
         raise SyntaxError("Tensors are not editable to preserve the autodiff chain. This feature might be added in the future. To update part of a tensor, use math.where() or math.scatter()")

-    def __unstack__(self, dims: Tuple[str, ...]) -> Tuple['Tensor', ...]:  # from phiml.math.magic.Sliceable
+    def __unstack__(self, dims: Tuple[str, ...]) -> Tuple['Tensor[T]', ...]:  # from phiml.math.magic.Sliceable
         if len(dims) == 1:
             return self._unstack(dims[0])
         else:
             return NotImplemented

-    def _unstack(self, dim: str):
+    def _unstack(self, dim: str) -> Tuple['Tensor[T]', ...]:
         """
         Splits this tensor along the specified dimension.
         The returned tensors have the same dimensions as this tensor save the unstacked dimension.
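A quick sketch of the unstacking contract the docstring describes, assuming the public `unstack` function and the `spatial`/`channel` shape constructors from `phiml.math`:

```python
from phiml.math import wrap, unstack, spatial, channel

x = wrap([[1, 2, 3], [4, 5, 6]], spatial('y'), channel('vector'))
rows = unstack(x, 'y')  # dispatches to __unstack__ / _unstack
# rows is a tuple of 2 tensors, each keeping the remaining (vector=3) shape
```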
@@ -568,12 +570,12 @@ def dimension(self, name: Union[str, Shape]) -> 'TensorDim':
         else:
             raise ValueError(name)

-    def pack(self, dims, packed_dim):
+    def pack(self, dims, packed_dim) -> 'Tensor[T]':
         """ See `pack_dims()` """
         from ._ops import pack_dims
         return pack_dims(self, dims, packed_dim)

-    def unpack(self, dim, unpacked_dims):
+    def unpack(self, dim, unpacked_dims) -> 'Tensor[T]':
         """ See `unpack_dim()` """
         from ._ops import unpack_dim
         return unpack_dim(self, dim, unpacked_dims)
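For readers skimming the diff, a round trip through these two delegating methods might look like this (a sketch; `instance` and `spatial` are the standard `phiml.math` shape constructors):

```python
from phiml.math import wrap, spatial, instance

grid = wrap([[1, 2], [3, 4]], spatial('y,x'))
flat = grid.pack('y,x', instance('points'))          # (points=4), same element type T
restored = flat.unpack('points', spatial(y=2, x=2))  # back to (y=2, x=2)
```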
@@ -584,15 +586,15 @@ def T(self):
         # return self._with_shape_replaced(self.shape.transposed())

     @property
-    def Ti(self):
+    def Ti(self) -> 'Tensor[T]':
         return self._with_shape_replaced(self.shape.transpose(INSTANCE_DIM))

     @property
-    def Tc(self):
+    def Tc(self) -> 'Tensor[T]':
         return self._with_shape_replaced(self.shape.transpose(CHANNEL_DIM))

     @property
-    def Ts(self):
+    def Ts(self) -> 'Tensor[T]':
         return self._with_shape_replaced(self.shape.transpose(SPATIAL_DIM))

     def map(self, function: Callable, dims=shape_, range=range, unwrap_scalars=True, **kwargs):
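The three transpose properties only permute dims of one type within the shape; the values and the element type are untouched, which is what the new `'Tensor[T]'` annotations record. A hypothetical use (dim names are illustrative):

```python
from phiml.math import wrap, channel

m = wrap([[1, 2], [3, 4]], channel('row'), channel('col'))  # shape (row=2, col=2)
mt = m.Tc  # channel dims transposed to (col=2, row=2); still the same T
```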
@@ -676,7 +678,7 @@ def __mod__(self, other):
     def __rmod__(self, other):
         return self._op2(other, operator.mod, True)

-    def __eq__(self, other) -> 'Tensor':
+    def __eq__(self, other) -> 'Tensor[bool]':
         if self is other:
             return expand(True, self.shape)
         if _EQUALITY_REDUCE[-1]['type'] == 'ref':
@@ -693,7 +695,7 @@ def __eq__(self, other) -> 'Tensor':
         else:
             return wrap(False)

-    def __ne__(self, other) -> 'Tensor':
+    def __ne__(self, other) -> 'Tensor[bool]':
         if _EQUALITY_REDUCE[-1]['type'] == 'ref':
             return wrap(self is not other)
         elif _EQUALITY_REDUCE[-1]['type'] == 'shape_and_value':
@@ -708,49 +710,49 @@ def __ne__(self, other) -> 'Tensor':
         else:
             return wrap(True)

-    def __lt__(self, other):
+    def __lt__(self, other) -> 'Tensor[bool]':
         return self._op2(other, operator.gt, True)

-    def __le__(self, other):
+    def __le__(self, other) -> 'Tensor[bool]':
         return self._op2(other, operator.ge, True)

-    def __gt__(self, other):
+    def __gt__(self, other) -> 'Tensor[bool]':
         return self._op2(other, operator.gt, False)

-    def __ge__(self, other):
+    def __ge__(self, other) -> 'Tensor[bool]':
         return self._op2(other, operator.ge, False)

-    def __lshift__(self, other):
+    def __lshift__(self, other) -> 'Tensor[T]':
         return self._op2(other, operator.lshift, False)

-    def __rlshift__(self, other):
+    def __rlshift__(self, other) -> 'Tensor[T]':
         return self._op2(other, operator.lshift, True)

-    def __rshift__(self, other):
+    def __rshift__(self, other) -> 'Tensor[T]':
         return self._op2(other, operator.rshift, False)

-    def __rrshift__(self, other):
+    def __rrshift__(self, other) -> 'Tensor[T]':
         return self._op2(other, operator.rshift, True)

-    def __abs__(self):
+    def __abs__(self) -> 'Tensor[T]':
         return self._op1(lambda t: choose_backend(t).abs(t))

-    def __round__(self, n=None):
+    def __round__(self, n=None) -> 'Tensor[T]':
         return self._op1(lambda t: choose_backend(t).round(t))

-    def __copy__(self):
+    def __copy__(self) -> 'Tensor[T]':
         return self._op1(lambda t: choose_backend(t).copy(t, only_mutable=True))

-    def __deepcopy__(self, memodict={}):
+    def __deepcopy__(self, memodict={}) -> 'Tensor[T]':
         return self._op1(lambda t: choose_backend(t).copy(t, only_mutable=False))

-    def __neg__(self) -> 'Tensor':
+    def __neg__(self) -> 'Tensor[T]':
         return self._op1(operator.neg)

-    def __invert__(self) -> 'Tensor':
+    def __invert__(self) -> 'Tensor[T]':
         return self._op1(lambda t: choose_backend(t).invert(t))

-    def __reversed__(self):
+    def __reversed__(self) -> 'Tensor[T]':
         assert self.shape.channel.rank == 1
         return self[::-1]

@@ -763,11 +765,11 @@ def __iter__(self):
         native = self.native([self.shape])
         return iter(native)

-    def item(self):
+    def item(self) -> T:
         assert self.shape.volume == 1, f"Tensor.item() is only available for single-element Tensors but got {self.shape}"
         return next(iter(self))

-    def __matmul__(self, other):
+    def __matmul__(self, other) -> 'Tensor[T]':
         from ._ops import dot
         assert isinstance(other, Tensor), f"Matmul '@' requires two Tensor arguments but got {type(other)}"
         if not self.shape.dual_rank and self.shape.channel_rank:
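Taken together, the annotations in this hunk encode a simple rule: comparisons produce boolean tensors, while shifts, negation, inversion, copies, reversal and matmul preserve the element type, and `item()` unwraps it. A small sketch:

```python
from phiml.math import wrap

a = wrap([1, 2, 3])
b = wrap([3, 2, 1])
lt = a < b       # Tensor[bool]; implemented as b > a via _op2(..., operator.gt, True)
doubled = a << 1 # Tensor[int]: bitwise shift keeps T
n: int = wrap(42).item()  # a volume-1 tensor unwraps to a plain Python value of type T
```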
@@ -953,7 +955,7 @@ def __init__(self, obj, stack_dim: Shape):
             warnings.warn(f"Empty stack_dim for Layout with value {obj}")

     @staticmethod
-    def _recursive_get_shapes(obj, s: Shape) -> Tuple[Shape]:
+    def _recursive_get_shapes(obj, s: Shape) -> Tuple[Shape, ...]:
         if not s:
             return shape(obj, allow_unshaped=True),
         elif isinstance(obj, (tuple, list)):
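The annotation fix here is easy to miss: `Tuple[Shape]` means a tuple of exactly one `Shape`, while `Tuple[Shape, ...]` is the variadic form this method actually returns. In plain `typing` terms:

```python
from typing import Tuple

exactly_one: Tuple[int] = (1,)           # fixed length: one element only
any_length: Tuple[int, ...] = (1, 2, 3)  # variadic: zero or more elements
```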
@@ -1670,10 +1672,10 @@ def _simplify(self):
         return self


-def tensor(data,
+def tensor(data: Union[Sequence[T], T],
            *shape: Union[Shape, str, list],
            convert: bool = True,
-           default_list_dim=channel('vector')) -> Tensor:  # TODO assume convert_unsupported, add convert_external=False for constants
+           default_list_dim=channel('vector')) -> Tensor[T]:  # TODO assume convert_unsupported, add convert_external=False for constants
     """
     Create a Tensor from the specified `data`.
     If `convert=True`, converts `data` to the preferred format of the default backend.
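As the docstring says, `convert=True` moves the data into the default backend's preferred format; the new signature additionally ties the element type of the input to the returned `Tensor[T]`. A usage sketch, assuming NumPy input and the `spatial` constructor:

```python
import numpy as np
from phiml.math import tensor, spatial

data = np.zeros((4, 3))
t = tensor(data, spatial('x,y'))  # converted to the default backend, e.g. torch if one is set
```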
@@ -1814,12 +1816,12 @@ def tensor(data,
     raise ValueError(f"{type(data)} is not supported. Only (Tensor, tuple, list, np.ndarray, native tensors) are allowed.\nCurrent backends: {BACKENDS}")


-def wrap(data, *shape: Union[Shape, str, list], default_list_dim=channel('vector')) -> Tensor:
+def wrap(data: Union[Sequence[T], T], *shape: Union[Shape, str, list], default_list_dim=channel('vector')) -> Tensor[T]:
     """ Short for `phiml.math.tensor()` with `convert=False`. """
     return tensor(data, *shape, convert=False, default_list_dim=default_list_dim)


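The practical difference to `tensor()` is only the skipped conversion, which the one-line body makes explicit. Sketch:

```python
import numpy as np
from phiml.math import tensor, wrap

a = np.array([1.0, 2.0, 3.0])
kept = wrap(a)     # convert=False: stays backed by the NumPy array
moved = tensor(a)  # may be converted to the default backend's format
```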
-def layout(objects, *shape: Union[Shape, str]) -> Tensor:
+def layout(objects: Union[Sequence[T], T], *shape: Union[Shape, str]) -> Tensor[T]:
     """
     Wraps a Python tree in a `Tensor`, allowing elements to be accessed via dimensions.
     A Python tree is a structure of nested `tuple`, `list`, `dict` and *leaf* objects where leaves can be any Python object.
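A sketch of the tree wrapping the docstring describes; the dim names and the `batch` constructor are illustrative, not prescribed by this diff:

```python
from phiml.math import layout, batch

tree = [('a', 1), ('b', 2)]
l = layout(tree, batch('pairs,item'))  # dims index into the nested structure
first = l.pairs[0]                     # -> layout of ('a', 1)
```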