a problem with using torchTT.interpolate.dmrg_cross() #18

Closed
Boyuan-Shi-0607 opened this issue Jan 22, 2024 · 2 comments

@Boyuan-Shi-0607

I tried to use interpolate.dmrg_cross() to approximate a tensor given as a numpy array.

Generate a random array:
import numpy as np
import torch as tn
import torchtt as tntt
test = np.random.rand(10, 10, 10)

Define the function in this way:

def func(args):                                        
    return tn.tensor([tn.from_numpy(test)[*args[0]]], dtype=tn.complex128)         

When I tried to use dmrg_cross to interpolate this example random tensor,

N = list(test.shape)                                   
x = tntt.interpolate.dmrg_cross(func, N, eps=10**(-8)) 

I got the error message:

Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
           ^^^^^^
  File "<input>", line 1, in <module>
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/boyuanshi/Desktop/second_project/equilibrium_v2/Screened_Interactions_Plot.py", line 61, in <module>
    x = tntt.interpolate.dmrg_cross(func, N, eps=10**(-8))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/boyuanshi/.conda/envs/Desktop/lib/python3.11/site-packages/torchTT-2.0-py3.11-macosx-10.9-x86_64.egg/torchtt/interpolate.py", line 514, in dmrg_cross
    supercore = tn.reshape(function(eval_index),[rank[k],N[k],N[k+1],rank[k+2]])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[1, 10, 10, 2]' is invalid for input of size 1

I am quite confused about why this happens.
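The numbers in the error message already point at the cause. A minimal sketch, assuming (as line 514 in the traceback suggests) that dmrg_cross evaluates the handle on a whole batch of indices at once and reshapes the result into a supercore of shape [rank[k], N[k], N[k+1], rank[k+2]]: the target shape [1, 10, 10, 2] needs 1·10·10·2 = 200 values, but a handle that only reads the first row of its input returns one value. The random index batch below is a hypothetical stand-in for the one dmrg_cross builds internally.

```python
import torch as tn

tn.manual_seed(0)
test = tn.rand(10, 10, 10, dtype=tn.float64)

def func_first_row_only(args):
    # Mimics the handle from the question: only args[0] is used,
    # so the result has shape (1,) no matter how many rows args has.
    return test[tuple(args[0])].reshape(1)

# Hypothetical stand-in for the (m, d) index batch dmrg_cross passes in:
# m = 1 * 10 * 10 * 2 = 200 rows, d = 3 columns (one per tensor mode).
m = 1 * 10 * 10 * 2
eval_index = tn.randint(0, 10, (m, 3))

values = func_first_row_only(eval_index)
print(values.shape)  # torch.Size([1])

try:
    tn.reshape(values, [1, 10, 10, 2])
except RuntimeError as err:
    print(err)  # shape '[1, 10, 10, 2]' is invalid for input of size 1
```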

@Boyuan-Shi-0607
Author

It seems that I somehow solved the issue by adding 0./(2+tn.exp(tn.sum(I, 1))) to the return value of the function and changing the data type to tn.float64, though I don't know why.
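A likely explanation for why this workaround silences the error, assuming `I` is the (m, d) index tensor the handle receives (the comment does not say): tn.sum(I, 1) has shape (m,), so 0./(2+tn.exp(...)) is a vector of m zeros, and adding it to the (1,) result broadcasts the output to shape (m,). The reshape then succeeds, but every entry in a batch repeats the value at the batch's first index, so the values handed to the cross approximation are not the entries of `test`. The sketch below reproduces this with a made-up batch of 200 indices.

```python
import torch as tn

tn.manual_seed(0)
test = tn.rand(10, 10, 10, dtype=tn.float64)

def func_patched(I):
    # Original handle: one value, shape (1,), taken from the first row only.
    first = test[tuple(I[0])].reshape(1)
    # The added term: summing over axis 1 gives shape (m,), so this is m zeros.
    # Cast to float64 here so tn.exp does not depend on integer-to-float promotion.
    zeros = 0. / (2 + tn.exp(tn.sum(I.to(tn.float64), 1)))
    # Broadcasting (1,) + (m,) -> (m,): the shape now matches what a batch needs.
    return first + zeros

I = tn.randint(0, 10, (200, 3))
out = func_patched(I)
print(out.shape)  # torch.Size([200])
print(bool(tn.all(out == out[0])))  # True: every entry repeats the first row's value

# The values a correct handle would return, one entry per row of I.
expected = test[I[:, 0], I[:, 1], I[:, 2]]
```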

@ion-g-ion
Owner

ion-g-ion commented Jan 25, 2024

Hi,

The output of the function handle should be a real (float) tensor of shape (m,) for an input of shape (m, d).
The provided handle returns a torch tensor of shape (1,). It can be modified as follows:

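def func(args):
    # args has shape (m, d), one multi-index per row; the result has shape (m,)
    # (the *a unpacking inside [] requires Python 3.11+)
    return tn.tensor([tn.from_numpy(test)[*a] for a in args], dtype=tn.float64)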
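The handle above loops over the rows in Python and calls from_numpy once per row. A vectorized sketch of the same (m, d) → (m,) contract converts the array once and uses advanced indexing with one index column per mode. The helper `check_handle` is hypothetical (not part of torchTT): it evaluates a handle on a few random multi-indices and asserts the output shape and dtype, which would have caught the original (1,)-shaped result before calling dmrg_cross.

```python
import numpy as np
import torch as tn

test = np.random.rand(10, 10, 10)
test_t = tn.from_numpy(test)  # convert once, outside the handle

def func(args):
    # args: integer tensor of shape (m, d), one multi-index per row.
    # Indexing with one index column per mode returns shape (m,).
    idx = args.to(tn.long)
    return test_t[tuple(idx[:, k] for k in range(idx.shape[1]))]

def check_handle(f, N, m=5):
    # Hypothetical helper: evaluate f on m random multi-indices and
    # confirm it returns one real value per row.
    idx = tn.stack([tn.randint(0, n, (m,)) for n in N], dim=1)
    out = f(idx)
    assert out.shape == (m,), f"expected shape ({m},), got {tuple(out.shape)}"
    assert not out.is_complex(), "complex output: only the real part would be used"
    return out

check_handle(func, list(test.shape))
```

With the check passing, the call from the question, `tntt.interpolate.dmrg_cross(func, N, eps=10**(-8))`, should receive one value per requested index, which is what its internal reshape into supercores expects.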

Regarding complex numbers: right now only the real part is approximated. One alternative is to cross-approximate the real and imaginary parts separately. Supporting complex numbers in all operations would require some more work on my side.

ion-g-ion self-assigned this Jan 25, 2024