Skip to content
Snippets Groups Projects
Commit b12e9100 authored by Floris Berendsen's avatar Floris Berendsen
Browse files

ENH: added tests for connecting diverse metrics and optimizers

parent b4e11563
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,8 @@ class Optimizer3rdPartyBase{
public:
virtual int SetMetric(Metric3rdPartyBase*) = 0;
virtual int Optimize() = 0;
protected:
Metric3rdPartyBase* theMetric;
};
class Metric4rdPartyBase{
......@@ -34,11 +36,37 @@ public:
class GDOptimizer3rdParty : public Optimizer3rdPartyBase {
public:
GDOptimizer3rdParty();
~GDOptimizer3rdParty();
virtual int SetMetric(Metric3rdPartyBase*);
virtual int Optimize();
Metric3rdPartyBase* metric;
};
GDOptimizer3rdParty::GDOptimizer3rdParty()
{
this->theMetric = nullptr;
}
GDOptimizer3rdParty::~GDOptimizer3rdParty()
{
}
int GDOptimizer3rdParty::SetMetric(Metric3rdPartyBase* metric)
{
this->theMetric = metric;
return 0;
}
int GDOptimizer3rdParty::Optimize()
{
if (this->theMetric != nullptr)
{
std::cout << "GDOptimizer3rdParty->Optimize():" << std::endl;
std::cout << " theMetric->GetValue():" << theMetric->GetValue() << std::endl;
std::cout << " theMetric->GetDerivative():" << theMetric->GetDerivative() << std::endl;
}
return 0;
}
class SSDMetric4rdParty : public Metric4rdPartyBase {
public:
virtual int GetCost() { return 3; };
......@@ -96,7 +124,7 @@ public:
class OptimizerDerivativeInterface {
public:
virtual int SetMeticDerivativeComponentInterface(MetricDerivativeInterface*) = 0;
virtual int SetMetricDerivativeComponentInterface(MetricDerivativeInterface*) = 0;
};
class OptimizerUpdateInterface {
......@@ -132,6 +160,75 @@ int SSDMetric3rdPartyComponent::GetValue()
return this->theImplementation->GetValue();
};
class Metric3rdPartyWrapper : public Metric3rdPartyBase {
public:
void SetMetricValueComponent(MetricValueInterface*);
void SetMetricDerivativeComponent(MetricDerivativeInterface*);
virtual int GetValue();
virtual int GetDerivative();
private:
MetricValueInterface* metricval;
MetricDerivativeInterface* metricderiv;
};
void Metric3rdPartyWrapper::SetMetricValueComponent(MetricValueInterface* metricValueComponent)
{
this->metricval = metricValueComponent;
}
int Metric3rdPartyWrapper::GetValue()
{
return this->metricval->GetValue();
}
void Metric3rdPartyWrapper::SetMetricDerivativeComponent(MetricDerivativeInterface* metricDerivativeComponent)
{
this->metricderiv = metricDerivativeComponent;
}
int Metric3rdPartyWrapper::GetDerivative()
{
return this->metricderiv->GetDerivative();
}
class GDOptimizer3rdPartyComponent : public ComponentBase, public OptimizerValueInterface, public OptimizerDerivativeInterface, public OptimizerUpdateInterface
{
public:
GDOptimizer3rdPartyComponent();
~GDOptimizer3rdPartyComponent();
GDOptimizer3rdParty* theImplementation;
Metric3rdPartyWrapper* MetricObject;
int SetMetricValueComponentInterface(MetricValueInterface*);
int SetMetricDerivativeComponentInterface(MetricDerivativeInterface*);
int Update();
};
GDOptimizer3rdPartyComponent::GDOptimizer3rdPartyComponent()
{
this->theImplementation = new GDOptimizer3rdParty();
this->MetricObject = new Metric3rdPartyWrapper();
}
GDOptimizer3rdPartyComponent::~GDOptimizer3rdPartyComponent()
{
delete this->theImplementation;
delete this->MetricObject;
}
int GDOptimizer3rdPartyComponent::SetMetricValueComponentInterface(MetricValueInterface* component)
{
this->MetricObject->SetMetricValueComponent(component);
return 0;
}
int GDOptimizer3rdPartyComponent::SetMetricDerivativeComponentInterface(MetricDerivativeInterface* component)
{
this->MetricObject->SetMetricDerivativeComponent(component);
return 0;
}
int GDOptimizer3rdPartyComponent::Update()
{
this->theImplementation->SetMetric(this->MetricObject);
return this->theImplementation->Optimize(); // 3rd party specific call
}
class SSDMetric4rdPartyComponent : public ComponentBase, public MetricValueInterface {
......@@ -181,7 +278,6 @@ public:
~GDOptimizer4rdPartyComponent();
GDOptimizer4rdParty* theImplementation;
Metric4rdPartyWrapper* MetricObject;
MetricDerivativeInterface* metricDerivativeInterface;
int SetMetricValueComponentInterface(MetricValueInterface*);
int Update();
};
......@@ -205,7 +301,7 @@ int GDOptimizer4rdPartyComponent::SetMetricValueComponentInterface(MetricValueIn
int GDOptimizer4rdPartyComponent::Update()
{
this->theImplementation->SetMetric(this->MetricObject);
return this->theImplementation->DoOptimization();
return this->theImplementation->DoOptimization(); // 4rd party specific call
}
int main () {
......@@ -219,13 +315,15 @@ int main () {
MetricDerivativeInterface* derIF = dynamic_cast<MetricDerivativeInterface*> (metric3p);
std::cout << derIF->GetDerivative() << std::endl;
}
/************************/
/************ Connect metric4p to optimizer4p ***********
* expected: ok
*/
{
SSDMetric4rdPartyComponent* tempmetric4p = new SSDMetric4rdPartyComponent();
ComponentBase* metric4p = tempmetric4p;
ComponentBase* metric4p = tempmetric4p; // type returned by our component factory
GDOptimizer4rdPartyComponent* tempOptimizer4p = new GDOptimizer4rdPartyComponent();
ComponentBase* optimizer4p = tempOptimizer4p;
ComponentBase* optimizer4p = tempOptimizer4p; // type returned by our component factory
MetricValueInterface* metvalIF = dynamic_cast<MetricValueInterface*> (metric4p);
if (!metvalIF)
......@@ -250,13 +348,16 @@ int main () {
// Update the optimizer component
opUpdIF->Update();
}
/************ Connect metric3p to optimizer4p ***********
* expected: ok
* optimizer4p will only use/have access to the GetValue interface of metric3p
*/
{
SSDMetric3rdPartyComponent* tempmetric3p = new SSDMetric3rdPartyComponent();
ComponentBase* metric3p = tempmetric3p;
ComponentBase* metric3p = tempmetric3p; // type returned by our component factory
GDOptimizer4rdPartyComponent* tempOptimizer4p = new GDOptimizer4rdPartyComponent();
ComponentBase* optimizer4p = tempOptimizer4p;
ComponentBase* optimizer4p = tempOptimizer4p; // type returned by our component factory
MetricValueInterface* metvalIF = dynamic_cast<MetricValueInterface*> (metric3p);
if (!metvalIF)
......@@ -281,6 +382,54 @@ int main () {
// Update the optimizer component
opUpdIF->Update();
}
/************ Connect metric4p to optimizer3p ***********
* expected: fail
* optimizer3p needs a metric with GetDerivative which metric4p doesn't have
*/
{
SSDMetric4rdPartyComponent* tempmetric4p = new SSDMetric4rdPartyComponent();
ComponentBase* metric4p = tempmetric4p; // type returned by our component factory
GDOptimizer3rdPartyComponent* tempOptimizer3p = new GDOptimizer3rdPartyComponent();
ComponentBase* optimizer3p = tempOptimizer3p; // type returned by our component factory
MetricValueInterface* metvalIF = dynamic_cast<MetricValueInterface*> (metric4p);
if (!metvalIF)
{
std::cout << "metric4p has no MetricValueInterface" << std::endl;
}
OptimizerValueInterface* opValIF = dynamic_cast<OptimizerValueInterface*> (optimizer3p);
if (!opValIF)
{
std::cout << "optimizer4p has no OptimizerValueInterface" << std::endl;
}
// connect value interfaces
opValIF->SetMetricValueComponentInterface(metvalIF);
MetricDerivativeInterface* metderivIF = dynamic_cast<MetricDerivativeInterface*> (metric4p);
if (!metderivIF)
{
std::cout << "metric4p has no MetricDerivativeInterface" << std::endl;
}
OptimizerDerivativeInterface* opDerivIF = dynamic_cast<OptimizerDerivativeInterface*> (optimizer3p);
if (!opDerivIF)
{
std::cout << "optimizer4p has no OptimizerDerivativeInterface" << std::endl;
}
// connect derivative interfaces
opDerivIF->SetMetricDerivativeComponentInterface(metderivIF);
OptimizerUpdateInterface* opUpdIF = dynamic_cast<OptimizerUpdateInterface*> (optimizer3p);
if (!opValIF)
{
std::cout << "optimizer3p has no OptimizerUpdateInterface" << std::endl;
}
// Update the optimizer component
// opUpdIF->Update(); // will fail since the metric does'nt have GetDerivative()
}
return 0;
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment