"""
Methods to load data from a spreadsheet (Excel) file database.
"""
from datetime import timedelta
from typing import BinaryIO, List, Optional

import numpy as np
import pandas as pd
from fastapi_camelcase import CamelModel
from pydantic import Field, ValidationError

from .. import schemas as s
def _parse_node(data: dict) -> s.AllNodes:
    """
    Helper function to manage node polymorphism.

    Tries each candidate node class in order and returns the first one that
    validates against the supplied data.

    Args:
        data (dict): the node in dictionary format

    Returns:
        AllNodes: the node in SpaceNet format

    Raises:
        ValueError: if the data does not validate against any node class
    """
    # loop through candidate model classes; a validation failure only means
    # this candidate does not fit, so move on to the next one
    for model_cls in [s.SurfaceNode, s.OrbitalNode, s.LagrangeNode]:
        try:
            return model_cls(**data)
        except ValidationError:
            continue
    # include the offending data, as the other `_parse_*` helpers do, so the
    # bad spreadsheet row can be identified from the error message
    raise ValueError("No valid node type found for " + str(data))
def _parse_edge(data: dict) -> s.AllEdges:
    """
    Helper function to manage edge polymorphism.

    The candidate edge classes are tried in a fixed order and the first one
    that accepts the data is returned.

    Args:
        data (dict): the edge in dictionary format

    Returns:
        AllEdges: the edge in SpaceNet format

    Raises:
        ValueError: if no candidate edge class validates the data
    """
    candidates = (s.SurfaceEdge, s.SpaceEdge, s.FlightEdge)
    for candidate in candidates:
        try:
            edge = candidate(**data)
        except ValidationError:
            # not this edge type; try the next candidate
            continue
        return edge
    message = "No valid edge type found for " + str(data)
    raise ValueError(message)
def _parse_resource(data: dict) -> s.AllResources:
    """
    Helper function to manage resource polymorphism.

    Tries each candidate resource class in order and returns the first one
    that validates. The input dictionary is not modified.

    Args:
        data (dict): the resource in dictionary format; must contain a
            `description` key

    Returns:
        AllResources: the resource in SpaceNet format

    Raises:
        KeyError: if `data` has no `description` key
        ValueError: if the data does not validate against any resource class
    """
    # work on a shallow copy so the caller's dictionary is left untouched
    fields = dict(data)
    # fix `description` null value (anything `pd.isna` reports as missing)
    if pd.isna(fields["description"]):
        fields["description"] = None
    # loop through candidate model classes
    for model_cls in [s.ContinuousResource, s.DiscreteResource]:
        try:
            return model_cls(**fields)
        except ValidationError:
            continue
    # report the normalized fields, i.e. exactly what was validated
    raise ValueError("No valid resource found for " + str(fields))
def _parse_element(data: dict) -> s.AllElements:
    """
    Helper function to manage element polymorphism.

    Tries each candidate element class in order and returns the first one
    that validates. The input dictionary is not modified.

    Args:
        data (dict): the element in dictionary format; must contain a
            `current_state_index` key

    Returns:
        AllElements: the element in SpaceNet format

    Raises:
        KeyError: if `data` has no `current_state_index` key
        ValueError: if the data does not validate against any element class
    """
    # work on a shallow copy so the caller's dictionary is left untouched
    fields = dict(data)
    # fix `current_state_index` null value (anything `pd.isna` reports as missing)
    if pd.isna(fields["current_state_index"]):
        fields["current_state_index"] = None
    # loop through candidate model classes; order matters because the first
    # class that validates is the one returned
    for model_cls in [
        s.Element,
        s.ResourceContainer,
        s.ElementCarrier,
        s.HumanAgent,
        s.RoboticAgent,
        s.PropulsiveVehicle,
        s.SurfaceVehicle,
    ]:
        try:
            return model_cls(**fields)
        except ValidationError:
            continue
    # report the normalized fields, i.e. exactly what was validated
    raise ValueError("No valid element found for " + str(fields))
def _parse_demand_model(
    data: dict,
) -> s.AllElementDemandModels:
    """
    Helper function to manage demand model polymorphism.

    The candidate demand model classes are tried in a fixed order and the
    first one that accepts the data is returned.

    Args:
        data (dict): the demand model in dictionary format

    Returns:
        AllElementDemandModels: the demand model in SpaceNet format

    Raises:
        ValueError: if no candidate demand model class validates the data
    """
    candidates = (
        s.TimedImpulseDemandModel,
        s.RatedDemandModel,
        s.SparingByMassDemandModel,
    )
    for candidate in candidates:
        try:
            demand_model = candidate(**data)
        except ValidationError:
            # not this demand model type; try the next candidate
            continue
        return demand_model
    message = "No valid demand model found for " + str(data)
    raise ValueError(message)
class ModelDatabase(CamelModel):
    """
    Database stores models for nodes, edges, resources, demand models, and elements.
    """

    nodes: List[s.AllNodes] = Field([], description="List of nodes")
    edges: List[s.AllEdges] = Field([], description="List of edges")
    resources: List[s.AllResources] = Field([], description="List of resources")
    demand_models: List[s.AllElementDemandModels] = Field(
        [], description="List of demand models"
    )
    elements: List[s.AllElements] = Field([], description="List of elements")

    def get_node(self, name: str) -> Optional[s.AllNodes]:
        """
        Gets the node matching a given name.

        Args:
            name (str): the node name

        Returns:
            AllNodes: The first matching node if it exists, otherwise `None`.
        """
        # the default makes a miss return `None` instead of raising StopIteration
        return next((node for node in self.nodes if node.name == name), None)

    def get_edge(self, name: str) -> Optional[s.AllEdges]:
        """
        Gets the edge matching a given name.

        Args:
            name (str): the edge name

        Returns:
            AllEdges: The first matching edge if it exists, otherwise `None`.
        """
        # the default makes a miss return `None` instead of raising StopIteration
        return next((edge for edge in self.edges if edge.name == name), None)

    def get_resource(self, name: str) -> Optional[s.AllResources]:
        """
        Gets the resource matching a given name.

        Args:
            name (str): the resource name

        Returns:
            AllResources: The first matching resource if it exists, otherwise `None`.
        """
        # the default makes a miss return `None` instead of raising StopIteration
        return next(
            (resource for resource in self.resources if resource.name == name), None
        )

    def instantiate_element(
        self,
        cls,
        template_name: str,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> s.AllInstElements:
        """
        Instantiates an element for a given template.

        The instance name is built as `"<prefix> | <template name> <suffix>"`,
        omitting the prefix and/or suffix parts when they are `None`.

        Args:
            cls: the element class
            template_name (str): the template name
            prefix (str, Optional): the prefix to apply before the instance name
            suffix (str, Optional): the suffix to apply after the instance name

        Returns:
            s.AllInstElements: An instantiated element.

        Raises:
            StopIteration: if no element template has the given name
        """
        # look the template up once and reuse it for both the id and the name
        template = next(filter(lambda o: o.name == template_name, self.elements))
        return cls(
            template_id=template.id,
            name=(prefix + " | " if prefix is not None else "")
            + template.name
            + (" " + suffix if suffix is not None else ""),
        )
# worksheets that make up a spreadsheet database, read together by `load_db`
_DB_SHEET_NAMES = [
    "nodes",
    "burns",
    "edges",
    "resources",
    "demands",
    "demand_models",
    "parts",
    "states",
    "contents",
    "elements",
]


def _lookup_model_id(frame: pd.DataFrame, row_id):
    """
    Looks up the identifier of the parsed model for a spreadsheet row.

    Args:
        frame (pd.DataFrame): a sheet with an `id` column and a parsed `model` column
        row_id: the spreadsheet `id` value to match

    Returns:
        The `id` attribute of the model in the first row whose `id` equals `row_id`.
    """
    return frame[frame.id == row_id].iloc[0].model.id


def _apply_to_rows(frame: pd.DataFrame, parser, drop_id: bool = False):
    """
    Applies a parser to every row of a sheet.

    Args:
        frame (pd.DataFrame): the sheet to parse
        parser: callable receiving one row (a pandas Series) and returning a model
        drop_id (bool): whether to drop the `id` column before parsing

    Returns:
        The row-wise results of `parser`, or `None` if the sheet is empty.
    """
    # emptiness is tested on the original frame, before any column is dropped,
    # so an empty sheet never reaches `drop` or `apply`
    if frame.empty:
        return None
    rows = frame.drop("id", axis=1) if drop_id else frame
    return rows.apply(parser, axis=1)


def load_db(db_file: BinaryIO) -> ModelDatabase:
    """
    Loads a database from file (Excel).

    Reads the `nodes`, `burns`, `edges`, `resources`, `demands`,
    `demand_models`, `parts`, `states`, `contents` and `elements` sheets,
    links rows across sheets through their spreadsheet `id` columns, and
    parses each row into the corresponding SpaceNet model.

    Args:
        db_file (BinaryIO): the database file object

    Returns:
        ModelDatabase: the model database

    Raises:
        ValueError: if a node, edge, resource, demand model or element row
            does not validate against any candidate model class
    """
    # read every sheet in one call so the workbook is only parsed once
    sheets = pd.read_excel(db_file, sheet_name=_DB_SHEET_NAMES)

    # parse the nodes, dropping the `id` field to generate a new uuid
    nodes = sheets["nodes"]
    nodes["model"] = _apply_to_rows(
        nodes, lambda r: _parse_node(r.to_dict()), drop_id=True
    )

    # parse the burns, converting the numeric time (in days) to a timedelta
    burns = sheets["burns"]
    burns["model"] = _apply_to_rows(
        burns, lambda r: s.Burn(time=timedelta(days=r.time), delta_v=r.delta_v)
    )

    edges = sheets["edges"]
    # add the `origin` field by matching with node ids
    edges["origin"] = (
        edges.origin_id.apply(lambda i: _lookup_model_id(nodes, i))
        if not edges.empty
        else None
    )
    # add the `destination` field by matching with node ids
    edges["destination"] = (
        edges.destination_id.apply(lambda i: _lookup_model_id(nodes, i))
        if not edges.empty
        else None
    )
    # convert the numeric duration (in days) to a Python timedelta
    edges["duration"] = edges.duration.apply(lambda i: timedelta(days=i))
    # add the `burns` field by matching with burn edge_ids
    edges["burns"] = (
        edges.id.apply(lambda i: burns[burns.edge_id == i].model.to_list())
        if not (edges.empty or burns.empty)
        else None
    )
    # parse the edges, dropping the `id` field to generate a new uuid
    edges["model"] = _apply_to_rows(
        edges, lambda r: _parse_edge(r.to_dict()), drop_id=True
    )

    # parse the resources, dropping the `id` field to generate a new uuid
    resources = sheets["resources"]
    resources["model"] = _apply_to_rows(
        resources, lambda r: _parse_resource(r.to_dict()), drop_id=True
    )

    def build_demand(r):
        # a positive `resource_id` references a row of the resources sheet;
        # otherwise its negation is used as a generic class of supply
        if r.resource_id > 0 and pd.notna(r.amount):
            return s.ResourceAmount(
                resource=_lookup_model_id(resources, r.resource_id),
                amount=r.amount,
            )
        if r.resource_id > 0 and pd.notna(r.rate):
            return s.ResourceAmountRate(
                resource=_lookup_model_id(resources, r.resource_id),
                rate=r.rate,
            )
        if pd.notna(r.amount):
            return s.GenericResourceAmount(
                class_of_supply=-r.resource_id,
                environment=s.Environment.UNPRESSURIZED,
                amount=r.amount,
            )
        return s.GenericResourceAmountRate(
            class_of_supply=-r.resource_id,
            environment=s.Environment.UNPRESSURIZED,
            rate=r.rate,
        )

    # parse the demands
    demands = sheets["demands"]
    demands["model"] = _apply_to_rows(demands, build_demand)

    demand_models = sheets["demand_models"]
    # add the `demands` field by matching with demand `demand_model_id`
    demand_models["demands"] = (
        demand_models.id.apply(
            lambda i: demands[demands.demand_model_id == i].model.to_list()
        )
        if not demand_models.empty
        else None
    )
    # parse the demand models, dropping the `id` field to generate a new uuid
    demand_models["model"] = _apply_to_rows(
        demand_models, lambda r: _parse_demand_model(r.to_dict()), drop_id=True
    )

    def build_part(r):
        # non-positive (or missing) failure/repair times are stored as `None`
        return s.Part(
            resource=_lookup_model_id(resources, r.resource_id),
            mean_time_to_failure=timedelta(hours=float(r.mean_time_to_failure))
            if r.mean_time_to_failure > 0
            else None,
            mean_time_to_repair=timedelta(hours=float(r.mean_time_to_repair))
            if r.mean_time_to_repair > 0
            else None,
            mass_to_repair=r.mass_to_repair,
            quantity=r.quantity,
            duty_cycle=r.duty_cycle,
        )

    # parse the parts
    parts = sheets["parts"]
    parts["model"] = _apply_to_rows(parts, build_part)

    def instantiate_demand_model(m):
        # pick the instantiated counterpart of the demand model template type
        if m.type == s.DemandModelType.TIMED_IMPULSE:
            return s.InstTimedImpulseDemandModel(name=m.name, template_id=m.id)
        if m.type == s.DemandModelType.RATED:
            return s.InstRatedDemandModel(name=m.name, template_id=m.id)
        return s.InstSparingByMassDemandModel(name=m.name, template_id=m.id)

    states = sheets["states"]
    # add the `demand_models` field by matching with demand model `state_id`
    states["demand_models"] = (
        states.id.apply(
            lambda i: [
                instantiate_demand_model(m)
                for m in demand_models[demand_models.state_id == i].model.to_list()
            ]
        )
        if not states.empty
        else None
    )
    # parse the states, dropping the `id` field to generate a new uuid
    states["model"] = _apply_to_rows(
        states, lambda r: s.State(**r.to_dict()), drop_id=True
    )

    def build_content(r):
        # a positive `resource_id` references a row of the resources sheet;
        # otherwise its negation is used as a generic class of supply
        if r.resource_id > 0:
            return s.ResourceAmount(
                resource=_lookup_model_id(resources, r.resource_id),
                amount=r.amount,
            )
        return s.GenericResourceAmount(
            class_of_supply=-r.resource_id,
            environment=s.Environment.UNPRESSURIZED,
            amount=r.amount,
        )

    # parse the resource contents
    contents = sheets["contents"]
    contents["model"] = _apply_to_rows(contents, build_content)

    elements = sheets["elements"]
    # add the `parts` field
    elements["parts"] = (
        elements.id.apply(lambda i: parts[parts.element_id == i].model.to_list())
        if not elements.empty
        else None
    )
    # add the `states` field
    elements["states"] = (
        elements.id.apply(lambda i: states[states.element_id == i].model.to_list())
        if not elements.empty
        else None
    )
    # add the `contents` field
    elements["contents"] = (
        elements.id.apply(
            lambda i: contents[contents.container_id == i].model.to_list()
        )
        if not elements.empty
        else None
    )
    # coerce `max_crew` to an integer data type
    elements["max_crew"] = elements.max_crew.astype("Int64")

    def initial_state_index(r):
        # position, among this element's own states, of the first state
        # flagged as initial; `None` when no state is flagged
        if not (states[(states.element_id == r.id) & (states.initial_state)].empty):
            return np.where(states[states.element_id == r.id].initial_state)[0][0]
        return None

    # add the `current_state_index` field
    elements["current_state_index"] = (
        elements.apply(initial_state_index, axis=1)
        if not (elements.empty or states.empty)
        else None
    )
    # coerce `current_state_index` to an integer data type
    elements["current_state_index"] = elements.current_state_index.astype("Int64")

    def build_fuel(r):
        # elements without a `fuel_id` carry no fuel
        if pd.notna(r.fuel_id) and r.fuel_id > 0:
            return s.ResourceAmount(
                resource=_lookup_model_id(resources, r.fuel_id),
                amount=r.max_fuel,
            )
        if pd.notna(r.fuel_id):
            return s.GenericResourceAmount(
                class_of_supply=-r.fuel_id,
                environment=s.Environment.UNPRESSURIZED,
                amount=r.max_fuel,
            )
        return None

    # add the `fuel` field
    elements["fuel"] = _apply_to_rows(elements, build_fuel)
    # parse the elements, dropping the `id` field to generate a new uuid
    elements["model"] = _apply_to_rows(
        elements, lambda r: _parse_element(r.to_dict()), drop_id=True
    )

    return ModelDatabase(
        nodes=nodes.model.to_list(),
        edges=edges.model.to_list(),
        resources=resources.model.to_list(),
        demand_models=demand_models.model.to_list(),
        elements=elements.model.to_list(),
    )