When operations are performed on a parametrized tensor, the shape declared for that tensor in the type hint does not change, even when the operation does change the actual shape:
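```python
import torch
from docarray import BaseDocument
from docarray.typing import TorchTensor

class Doc(BaseDocument):
    tensor: TorchTensor[3, 1]

d = Doc(tensor=torch.rand(size=(3, 1)))
d_t = d.tensor.transpose(0, 1)
print(d_t)
print(d_t.shape)
```

Output:

```
TorchTensor[3, 1]([[0.7769, 0.8053, 0.0161]])
torch.Size([1, 3])
```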
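The mismatch follows from how `torch.Tensor` subclasses behave by default: the base `__torch_function__` re-wraps every result in the subclass of the input. A minimal sketch in plain torch, assuming (as the output above suggests) that the parametrized type is a `torch.Tensor` subclass carrying its shape at class level. `FixedShapeTensor` and `declared_shape` are hypothetical stand-ins, not docarray's actual internals:

```python
import torch


class FixedShapeTensor(torch.Tensor):
    """Hypothetical stand-in for TorchTensor[3, 1]: the shape is stored on the class."""

    declared_shape = (3, 1)


t = torch.rand(3, 1).as_subclass(FixedShapeTensor)
t_t = t.transpose(0, 1)

# torch's default __torch_function__ re-wraps results in the input's subclass,
# so the class-level shape travels along unchanged while the data is transposed.
print(type(t_t).__name__)        # FixedShapeTensor
print(type(t_t).declared_shape)  # (3, 1)
print(tuple(t_t.shape))          # (1, 3)
```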
I see two options to tackle this:
- Override `__torch_function__` to correctly assign the current shape to the tensor class after every operation. In the example above, the first print would then produce `TorchTensor[1, 3]([[0.7769, 0.8053, 0.0161]])`. The problem with this: the transformation has to happen on every torch operation, which seems error-prone.
- Override `__torch_function__` to return a `torch.Tensor` instead of a `TorchTensor`. The challenge with this is that, in the current state, it would make our type system useless outside of `Document`. For example, the following would not type check:
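```python
import torch
from docarray import BaseDocument
from docarray.typing import TorchTensor

def my_helper(t: TorchTensor[512]):
    ...

class Doc(BaseDocument):
    tensor: TorchTensor[512]

d = Doc(tensor=torch.rand(512))
t = d.tensor + d.tensor  # t is now a torch.Tensor
my_helper(t)  # but my_helper expects a TorchTensor[512]
```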
To make this work again, we should look at TorchTyping and see how it achieves proper typing even though the data is a plain `torch.Tensor`.
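As a generic illustration of that idea (not TorchTyping's actual implementation), the shape can live in a `typing.Annotated` hint while the value stays a plain `torch.Tensor`. Static type checkers only see `torch.Tensor`, so `my_helper(t)` type checks, and a hypothetical `check_shapes` decorator enforces the shape at call time. A minimal sketch, assuming Python 3.9+; `check_shapes` and the `(512,)` metadata convention are invented for this example:

```python
import functools
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints

import torch


def check_shapes(fn):
    """Hypothetical decorator: enforce Annotated[torch.Tensor, shape] hints at call time."""
    hints = get_type_hints(fn, include_extras=True)
    signature = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = hints.get(name)
            if get_origin(hint) is not Annotated:
                continue  # parameter has no shape annotation
            expected = get_args(hint)[1]  # first metadata item: the expected shape
            if not isinstance(value, torch.Tensor) or tuple(value.shape) != tuple(expected):
                raise TypeError(f"{name}: expected a tensor of shape {tuple(expected)}")
        return fn(*args, **kwargs)

    return wrapper


@check_shapes
def my_helper(t: Annotated[torch.Tensor, (512,)]):
    return t.sum()


x = torch.rand(512)
t = x + x     # plain torch.Tensor, as under option 2
my_helper(t)  # accepted: static checkers see torch.Tensor, the runtime check sees (512,)
```

The trade-off is that shape checking moves from the tensor's class to the function boundary, so nothing has to be re-wrapped after each torch operation.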
Big advantage of option 2: as soon as any operation is performed on a `TorchTensor`, the result is a plain `torch.Tensor`, so there is no opportunity for us to f*ck things up inside a model.
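Option 2 could be sketched roughly as follows. `PlainResultTensor` is a hypothetical name, not docarray's class. The assumption is that calling the base `torch.Tensor.__torch_function__` with `torch.Tensor` as the class (and in `types`) runs the op with subclass dispatch disabled and returns its results without re-wrapping them in the subclass:

```python
import torch


class PlainResultTensor(torch.Tensor):
    """Hypothetical subclass: every torch op on it returns a plain torch.Tensor."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Delegate to the default handler *as torch.Tensor*: with cls=torch.Tensor
        # it runs `func` and hands back the raw results instead of converting
        # them to PlainResultTensor.
        return torch.Tensor.__torch_function__(func, (torch.Tensor,), args, kwargs)


t = torch.rand(3, 1).as_subclass(PlainResultTensor)
print(type(t).__name__)                  # PlainResultTensor
print(type(t.transpose(0, 1)).__name__)  # Tensor
print(type(t + t).__name__)              # Tensor
```

One caveat: in-place methods such as `add_` return `self`, so their result is still the original subclass instance.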
Preliminary conclusion: Let's do it properly through option 2.
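For comparison, option 1 might look like the sketch below. It assumes one subclass per concrete shape, produced by a hypothetical `shaped` factory; `ShapedTensor`, `declared_shape` and `_rewrap` are illustrative names, not docarray's API. After each op, every plain-tensor result is re-wrapped in the subclass matching its actual shape:

```python
import torch
from functools import lru_cache


class ShapedTensor(torch.Tensor):
    """Hypothetical base; `declared_shape` plays the role of [3, 1] in TorchTensor[3, 1]."""

    declared_shape = None

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Run the op and get the raw (plain torch.Tensor) results back.
        ret = torch.Tensor.__torch_function__(func, (torch.Tensor,), args, kwargs)
        return _rewrap(ret)


@lru_cache(maxsize=None)
def shaped(shape):
    """Hypothetical factory: one cached ShapedTensor subclass per concrete shape."""
    return type(f"ShapedTensor{list(shape)}", (ShapedTensor,), {"declared_shape": shape})


def _rewrap(ret):
    # Attach a declared shape that matches the *actual* shape of the result.
    if type(ret) is torch.Tensor:
        return ret.as_subclass(shaped(tuple(ret.shape)))
    if type(ret) in (tuple, list):
        return type(ret)(_rewrap(r) for r in ret)
    # Non-tensor results (torch.Size, strings from __repr__, None, ...) and
    # results that are already subclass instances (in-place ops) pass through.
    return ret


t = torch.rand(3, 1).as_subclass(shaped((3, 1)))
t_t = t.transpose(0, 1)
print(type(t_t).__name__)  # ShapedTensor[1, 3]
```

Even in this small sketch, every op result goes through `_rewrap` and in-place ops are left with their old declared shape, which hints at why option 1 seemed error-prone.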