Merge pull request #47 from michaelosthege/pytensor-2.18-compat
Compatibility fix for PyTensor >=2.18.1
michaelosthege committed Nov 26, 2023
2 parents 37ae0fe + 5e425a6 commit b86666a
Showing 4 changed files with 6 additions and 16 deletions.
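
Background (not part of the commit itself): as of PyTensor 2.18 the `params` argument is no longer passed to `Op.perform`, and the `ParamsInputType` alias is no longer imported from `pytensor.graph.op`, which is why every `perform` / `perform_async` override below drops that parameter. A minimal sketch of the post-2.18 signature, using a hypothetical `DoubleOp` purely for illustration:

from typing import Any, Sequence

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op, OutputStorageType


class DoubleOp(Op):
    """Hypothetical Op used only to illustrate the updated signature."""

    def make_node(self, x) -> Apply:
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(
        self,
        node: Apply,
        inputs: Sequence[Any],
        output_storage: OutputStorageType,
        # PyTensor >= 2.18: no `params: ParamsInputType = None` parameter here
    ) -> None:
        (x,) = inputs
        output_storage[0][0] = np.asarray(2 * x)

The diffs below apply this same signature change to every Op in the package.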
2 changes: 1 addition & 1 deletion environment.yml
@@ -12,4 +12,4 @@ dependencies:
   - psutil
   - pip:
     - betterproto[compiler]==2.0.0b6
-    - pymc==5.8.0
+    - pymc==5.10.0
2 changes: 1 addition & 1 deletion pytensor_federated/__init__.py
@@ -19,4 +19,4 @@
 from .service import ArraysToArraysService, ArraysToArraysServiceClient
 from .signatures import ComputeFunc, LogpFunc, LogpGradFunc
 
-__version__ = "1.0.0"
+__version__ = "1.0.1"
12 changes: 3 additions & 9 deletions pytensor_federated/op_async.py
@@ -7,7 +7,7 @@
 from pytensor.graph import FunctionGraph
 from pytensor.graph.basic import Apply, Variable, apply_depends_on
 from pytensor.graph.features import ReplaceValidate
-from pytensor.graph.op import Op, OutputStorageType, ParamsInputType
+from pytensor.graph.op import Op, OutputStorageType
 from pytensor.graph.rewriting.basic import GraphRewriter
 
 from .utils import get_useful_event_loop
@@ -19,10 +19,9 @@ def perform(
         node: Apply,
         inputs: Sequence[Any],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         loop = get_useful_event_loop()
-        coro = self.perform_async(node, inputs, output_storage, params)
+        coro = self.perform_async(node, inputs, output_storage)
         loop.run_until_complete(coro)
         return

@@ -31,7 +30,6 @@ async def perform_async(
         node: Apply,
         inputs: Sequence[Any],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         raise NotImplementedError()

@@ -57,7 +55,6 @@ async def perform_async(
         node: Apply,
         inputs: Sequence[Any],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         outs = await self.__async_fn(*inputs)
         if not isinstance(outs, (list, tuple)):
@@ -112,7 +109,6 @@ def perform(
         node: Apply,
         inputs: Sequence[Any],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         # Create coroutines performing the task of each child node
         coros = []
@@ -122,9 +118,7 @@
             ito = ifrom + apply.nin
             oto = ofrom + apply.nout
             coros.append(
-                apply.op.perform_async(
-                    apply, inputs[ifrom:ito], output_storage[ofrom:oto], params=params
-                )
+                apply.op.perform_async(apply, inputs[ifrom:ito], output_storage[ofrom:oto])
             )
             ifrom = ito
             ofrom = oto
6 changes: 1 addition & 5 deletions pytensor_federated/wrapper_ops.py
@@ -5,7 +5,7 @@
 import pytensor.tensor as at
 from pytensor.compile.ops import FromFunctionOp
 from pytensor.graph.basic import Apply, Variable
-from pytensor.graph.op import Op, OutputStorageType, ParamsInputType
+from pytensor.graph.op import Op, OutputStorageType
 
 from .op_async import AsyncFromFunctionOp, AsyncOp
 from .signatures import ComputeFunc, LogpFunc, LogpGradFunc
@@ -63,7 +63,6 @@ def perform(
         node: Apply,
         inputs: Sequence[np.ndarray],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         logp = self._logp_func(*inputs)
         output_storage[0][0] = logp
@@ -76,7 +75,6 @@ async def perform_async(
         node: Apply,
         inputs: Sequence[Any],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         logp = await self._logp_func(*inputs)
         output_storage[0][0] = logp
@@ -111,7 +109,6 @@ def perform(
         node: Apply,
         inputs: Sequence[np.ndarray],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         logp, gradient = self._logp_grad_func(*inputs)
         output_storage[0][0] = logp
@@ -141,7 +138,6 @@ async def perform_async(
         node: Apply,
         inputs: Sequence[Any],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         logp, gradient = await self._logp_grad_func(*inputs)
         output_storage[0][0] = logp
