Skip to content

Commit

Permalink
Abstract parameters class (#6)
Browse files Browse the repository at this point in the history
* make Parameters a subclass-able abstract

* add handling for unordered XML

* update missed test
  • Loading branch information
jwfraustro authored Dec 12, 2023
1 parent d2ba22f commit 8ced82a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 46 deletions.
85 changes: 47 additions & 38 deletions tests/uws/uws_models_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Tests for the XML serialization of UWS elements"""

from datetime import timezone as tz
from typing import Optional
from unittest import TestCase
from xml.etree.ElementTree import canonicalize

from lxml import etree
from pydantic_xml import element

from vo_models.xml.uws import (
ErrorSummary,
Expand Down Expand Up @@ -238,6 +240,15 @@ def test_write_to_xml(self):
class TestParametersElement(TestCase):
"""Test the UWS Parameters element"""

class TestParameters(Parameters):
"""A test subclass of Parameters.
This subclass is used to test the Parameters element, which is an abstract type."""

param1: Optional[Parameter] = element(tag="parameter", default=None)
param2: Optional[Parameter] = element(tag="parameter", default=None)
param3: Optional[Parameter] = element(tag="parameter", default=None)

test_parameters_xml = (
f"<uws:parameters {UWS_NAMESPACE_HEADER}>"
'<uws:parameter byReference="false" isPost="false" id="param1">value1</uws:parameter>'
Expand All @@ -249,26 +260,24 @@ class TestParametersElement(TestCase):
def test_read_from_xml(self):
"""Test reading from XML"""

parameters = Parameters.from_xml(self.test_parameters_xml)
self.assertEqual(len(parameters.parameter), 3)
parameters = self.TestParameters.from_xml(self.test_parameters_xml)
self.assertEqual(len(parameters.dict()), 3)

self.assertEqual(parameters.parameter[0].id, "param1")
self.assertEqual(parameters.parameter[1].id, "param2")
self.assertEqual(parameters.parameter[2].id, "param3")
self.assertEqual(parameters.param1.id, "param1")
self.assertEqual(parameters.param2.id, "param2")
self.assertEqual(parameters.param3.id, "param3")

self.assertEqual(parameters.parameter[0].value, "value1")
self.assertEqual(parameters.parameter[1].value, "value2")
self.assertEqual(parameters.parameter[2].value, "value3")
self.assertEqual(parameters.param1.value, "value1")
self.assertEqual(parameters.param2.value, "value2")
self.assertEqual(parameters.param3.value, "value3")

def test_write_to_xml(self):
"""Test writing to XML"""

parameters_element = Parameters(
parameter=[
Parameter(id="param1", value="value1"),
Parameter(id="param2", value="value2"),
Parameter(id="param3", value="value3"),
]
parameters_element = self.TestParameters(
param1=Parameter(id="param1", value="value1"),
param2=Parameter(id="param2", value="value2"),
param3=Parameter(id="param3", value="value3"),
)
parameters_xml = parameters_element.to_xml(skip_empty=True, encoding=str)

Expand All @@ -280,12 +289,10 @@ def test_write_to_xml(self):
def test_validate(self):
"""Test validation against XML schema"""

parameters = Parameters(
parameter=[
Parameter(id="param1", value="value1"),
Parameter(id="param2", value="value2"),
Parameter(id="param3", value="value3"),
]
parameters = self.TestParameters(
param1=Parameter(id="param1", value="value1"),
param2=Parameter(id="param2", value="value2"),
param3=Parameter(id="param3", value="value3"),
)
parameters_xml = etree.fromstring(parameters.to_xml(skip_empty=True, encoding=str))
uws_schema.assertValid(parameters_xml)
Expand All @@ -294,6 +301,12 @@ def test_validate(self):
class TestJobSummaryElement(TestCase):
"""Test the UWS JobSummary element"""

class TestParameters(Parameters):
"""A test subclass of Parameters."""

param1: Optional[Parameter] = element(tag="parameter", default=None)
param2: Optional[Parameter] = element(tag="parameter", default=None)

job_summary_xml = (
f'<uws:job {UWS_NAMESPACE_HEADER} version="1.1">'
"<uws:jobId>jobId1</uws:jobId>"
Expand All @@ -318,7 +331,7 @@ class TestJobSummaryElement(TestCase):
def test_read_from_xml(self):
"""Test reading from XML"""

job_summary = JobSummary[Parameters].from_xml(self.job_summary_xml)
job_summary = JobSummary[self.TestParameters].from_xml(self.job_summary_xml)
self.assertEqual(job_summary.job_id, "jobId1")
self.assertEqual(job_summary.run_id, "runId1")
self.assertEqual(job_summary.owner_id, "ownerId1")
Expand All @@ -329,19 +342,19 @@ def test_read_from_xml(self):
self.assertEqual(job_summary.end_time, UTCTimestamp(1900, 1, 1, 1, 1, 1, tzinfo=tz.utc))
self.assertEqual(job_summary.execution_duration, 0)
self.assertEqual(job_summary.destruction, UTCTimestamp(1900, 1, 1, 1, 1, 1, tzinfo=tz.utc))
self.assertEqual(len(job_summary.parameters.parameter), 2)
self.assertEqual(job_summary.parameters.parameter[0].id, "param1")
self.assertEqual(job_summary.parameters.parameter[1].id, "param2")
self.assertEqual(job_summary.parameters.parameter[0].value, "value1")
self.assertEqual(job_summary.parameters.parameter[1].value, "value2")
self.assertEqual(len(job_summary.parameters.dict()), 2)
self.assertEqual(job_summary.parameters.param1.id, "param1")
self.assertEqual(job_summary.parameters.param2.id, "param2")
self.assertEqual(job_summary.parameters.param1.value, "value1")
self.assertEqual(job_summary.parameters.param2.value, "value2")
self.assertEqual(len(job_summary.results.results), 0)
self.assertEqual(job_summary.error_summary, None)
self.assertEqual(job_summary.job_info[0], "jobInfo1")

def test_write_to_xml(self):
"""Test writing to XML"""

job_summary = JobSummary[Parameters](
job_summary = JobSummary[self.TestParameters](
job_id="jobId1",
run_id="runId1",
owner_id="ownerId1",
Expand All @@ -352,11 +365,9 @@ def test_write_to_xml(self):
end_time=UTCTimestamp(1900, 1, 1, 1, 1, 1, tzinfo=tz.utc),
execution_duration=0,
destruction=UTCTimestamp(1900, 1, 1, 1, 1, 1, tzinfo=tz.utc),
parameters=Parameters(
parameter=[
Parameter(id="param1", value="value1"),
Parameter(id="param2", value="value2"),
]
parameters=self.TestParameters(
param1=Parameter(id="param1", value="value1"),
param2=Parameter(id="param2", value="value2"),
),
results=Results(),
job_info=["jobInfo1"],
Expand All @@ -371,7 +382,7 @@ def test_write_to_xml(self):
def test_validate(self):
"""Validate against the schema"""

job_summary = JobSummary[Parameters](
job_summary = JobSummary[self.TestParameters](
job_id="jobId1",
run_id="runId1",
owner_id="ownerId1",
Expand All @@ -382,11 +393,9 @@ def test_validate(self):
end_time=None,
execution_duration=0,
destruction=UTCTimestamp(1900, 1, 1, 1, 1, 1, tzinfo=tz.utc),
parameters=Parameters(
parameter=[
Parameter(id="param1", value="value1"),
Parameter(id="param2", value="value2"),
]
parameters=self.TestParameters(
param1=Parameter(id="param1", value="value1"),
param2=Parameter(id="param2", value="value2"),
),
results=Results(results=[ResultReference(id="result1")]),
error_summary=None,
Expand Down
23 changes: 15 additions & 8 deletions vo_models/xml/uws/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from pydantic import field_validator
from pydantic_xml import BaseXmlModel, attr, element

from vo_models.xml.voresource.types import UTCTimestamp
from vo_models.xml.uws.types import ErrorType, ExecutionPhase, UWSVersion
from vo_models.xml.voresource.types import UTCTimestamp
from vo_models.xml.xlink import XlinkType

NSMAP = {
Expand Down Expand Up @@ -54,13 +54,20 @@ def validate_value(cls, value): # pylint: disable=no-self-argument


class Parameters(BaseXmlModel, tag="parameters", ns="uws", nsmap=NSMAP):
"""A list of UWS Job parameters.
Elements:
parameter (Parameter): a UWS Job parameter.
"""

parameter: Optional[list[Parameter]] = element(name="parameter", default_factory=list)
"""An abstract holder of UWS parameters."""

def __init__(__pydantic_self__, **data) -> None: # pylint: disable=no-self-argument
# during init -- especially if reading from xml -- we may not get the parameters in the order
# pydantic-xml expects. This will remap the dict with keys based on the parameter id.
parameter_vals = [val for val in data.values() if val is not None]
remapped_vals = {}
for param in parameter_vals:
if isinstance(param, dict):
remapped_vals[param["id"]] = Parameter(**param)
else:
remapped_vals[param.id] = param
data = remapped_vals
super().__init__(**data)


class ErrorSummary(BaseXmlModel, tag="errorSummary", ns="uws", nsmap=NSMAP):
Expand Down

0 comments on commit 8ced82a

Please sign in to comment.