Skip to content

Commit

Permalink
Merge branch 'zinggAI:main' into sania
Browse files Browse the repository at this point in the history
  • Loading branch information
sania-16 authored Nov 19, 2024
2 parents 2b229be + 7af8b70 commit f335f67
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 10 deletions.
8 changes: 7 additions & 1 deletion common/core/src/main/java/zingg/common/core/block/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ public void estimateElimCount(Canopy<R> c, long elimCount) {
long least = Long.MAX_VALUE;
int maxElimination = 0;
Canopy<R>best = null;
for (FieldDefinition field : fieldsOfInterest) {
List<FieldDefinition> adjustedFieldOfInterestList = getFieldOfInterestList(fieldsOfInterest, node);
for (FieldDefinition field : adjustedFieldOfInterestList) {
if (LOG.isDebugEnabled()){
LOG.debug("Trying for " + field + " with data type " + field.getDataType() + " and real dt "
+ getFeatureFactory().getDataTypeFromString(field.getDataType()));
Expand Down Expand Up @@ -404,6 +405,11 @@ public void printTree(Tree<Canopy<R>> tree,
}
}

public List<FieldDefinition> getFieldOfInterestList(List<FieldDefinition> fieldDefinitions, Canopy<R> node) {
FieldDefinitionStrategy<R> fieldDefinitionStrategy = new DefaultFieldDefinitionStrategy<R>();
return fieldDefinitionStrategy.getAdjustedFieldDefinitions(fieldDefinitions, node);
}

public abstract FeatureFactory<T> getFeatureFactory();


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package zingg.common.core.block;

import zingg.common.client.FieldDefinition;

import java.util.List;

public class DefaultFieldDefinitionStrategy<R> implements FieldDefinitionStrategy<R> {
@Override
public List<FieldDefinition> getAdjustedFieldDefinitions(List<FieldDefinition> fieldDefinitions, Canopy<R> node) {
//returning fieldDefinitions
//as it is here
return fieldDefinitions;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package zingg.common.core.block;

import zingg.common.client.FieldDefinition;

import java.util.List;

public interface FieldDefinitionStrategy<R> {
List<FieldDefinition> getAdjustedFieldDefinitions(List<FieldDefinition> fieldDefinitions, Canopy<R> node);
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,17 @@ public TestExecutorsSingle(){

@Override
public List<ExecutorTester<S, D, R, C, T>> getExecutors() throws ZinggClientException, IOException{
TrainingDataFinder<S, D, R, C, T> tdf = getTrainingDataFinder();
Labeller<S, D, R, C, T> labeler = getLabeller();

getBaseExecutors();
getAdditionalExecutors();

return executorTesterList;
}

public void getBaseExecutors() throws ZinggClientException, IOException{

TrainingDataFinder<S, D, R, C, T> tdf = getTrainingDataFinder();
Labeller<S, D, R, C, T> labeler = getLabeller();
executorTesterList.add(new FtdAndLabelCombinedExecutorTester<S, D, R, C, T>(tdf, new TrainingDataFinderValidator<S, D, R, C, T>(tdf), getConfigFile(),
labeler, new LabellerValidator<S, D, R, C, T>(labeler), modelId, getDFObjectUtil()));

Expand All @@ -38,13 +47,16 @@ public List<ExecutorTester<S, D, R, C, T>> getExecutors() throws ZinggClientExce
VerifyBlocking<S, D, R, C, T> verifyBlocker = getVerifyBlocker();
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(verifyBlocker, new BlockerValidator<S, D, R, C, T>(verifyBlocker),getConfigFile(),modelId,getDFObjectUtil()));

}

public void getAdditionalExecutors() throws ZinggClientException, IOException{

Matcher<S, D, R, C, T> matcher = getMatcher();
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(matcher,new MatcherValidator<S, D, R, C, T>(matcher),getConfigFile(),modelId,getDFObjectUtil()));

Linker<S, D, R, C, T> linker = getLinker();
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(linker,new LinkerValidator<S, D, R, C, T>(linker),getLinkerConfigFile(),modelId,getDFObjectUtil()));

return executorTesterList;

}

public abstract String getConfigFile();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
package zingg.common.core.executor.validate;

import static org.junit.jupiter.api.Assertions.assertTrue;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.util.ColName;
import zingg.common.core.executor.FindAndLabeller;

public class FindAndLabelValidator<S, D, R, C, T> extends ExecutorValidator<S, D, R, C, T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.IArguments;
import zingg.common.client.ZinggClientException;
import zingg.common.core.executor.Trainer;

Expand Down

0 comments on commit f335f67

Please sign in to comment.