Skip to content

Instantly share code, notes, and snippets.

@shamsimam
Last active February 21, 2021 17:44
Show Gist options
  • Select an option

  • Save shamsimam/0ef3cba95c12acc0511504558e9c8ea5 to your computer and use it in GitHub Desktop.

Select an option

Save shamsimam/0ef3cba95c12acc0511504558e9c8ea5 to your computer and use it in GitHub Desktop.
Determine the generic type arguments of an implementation at runtime
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.stream.Collectors;
/**
 * Resolves, at runtime, the actual type arguments that a concrete class supplies to a generic
 * interface somewhere in its type hierarchy.
 *
 * <p>Type-variable bindings are propagated through parameterized superclasses and
 * super-interfaces. For example, given {@code interface B<T> extends Function<T, T>} and
 * {@code class A implements B<String>}, the arguments of {@code Function} for {@code A}
 * resolve to {@code [String, String]}. A type argument that is itself parameterized is
 * reported as its raw class (e.g. {@code List<Integer>} becomes {@code List.class}).
 *
 * @author <a href="mailto:shams.imam+github@gmail.com">Shams Imam</a>
 */
public class GenericResolution {

    /** A type still to be explored, paired with the bindings for the type variables it may use. */
    private static class Entry {
        final Type type;
        final Map<TypeVariable<?>, Type> environment;

        Entry(Type type, Map<TypeVariable<?>, Type> environment) {
            this.type = type;
            this.environment = environment;
        }
    }

    /**
     * Returns the formal type parameters declared by the raw type of {@code parameterizedType},
     * in declaration order (the same order as its actual type arguments).
     */
    private static TypeVariable<?>[] retrieveTypeVariables(ParameterizedType parameterizedType) {
        // Use the public Class#getTypeParameters() rather than reading private JDK fields via
        // setAccessible: such access to java.base internals is denied by default since JDK 16.
        return ((Class<?>) parameterizedType.getRawType()).getTypeParameters();
    }

    /**
     * Substitutes {@code type} under {@code environment}.
     *
     * @return the binding of a type variable ({@code null} if it is unbound), the (resolved) raw
     *     type of a parameterized type, and {@code type} unchanged otherwise (including null)
     */
    private static Type resolveType(Map<TypeVariable<?>, Type> environment, Type type) {
        if (type instanceof TypeVariable) {
            return environment.get(type);
        }
        if (type instanceof ParameterizedType) {
            // Only the raw type is reported, e.g. List<W> resolves to List.class.
            return resolveType(environment, ((ParameterizedType) type).getRawType());
        }
        return type;
    }

    /**
     * Builds the bindings for the raw type's own type variables, resolving each actual type
     * argument of {@code parameterizedType} against {@code currentEnv}.
     */
    private static Map<TypeVariable<?>, Type> computeParameterizedEnvironment(
            ParameterizedType parameterizedType, Map<TypeVariable<?>, Type> currentEnv) {
        Type[] typeArguments = parameterizedType.getActualTypeArguments();
        TypeVariable<?>[] typeVariables = retrieveTypeVariables(parameterizedType);
        Map<TypeVariable<?>, Type> newEnv = new HashMap<>();
        for (int i = 0; i < typeVariables.length; i++) {
            newEnv.put(typeVariables[i], resolveType(currentEnv, typeArguments[i]));
        }
        return newEnv;
    }

    /**
     * Collects the generic interfaces declared by {@code instanceClass} and by every superclass,
     * each paired with the environment valid at that level of the hierarchy.
     *
     * @param instanceClass the class to start from; {@code null} yields an empty list
     * @param environment bindings for the type variables visible in {@code instanceClass}
     * @return entries for the declared interfaces, from {@code instanceClass} upwards
     */
    private static List<Entry> retrieveGenericInterfaces(
            Class<?> instanceClass, Map<TypeVariable<?>, Type> environment) {
        List<Entry> resultList = new ArrayList<>();
        Class<?> currentClass = instanceClass;
        Map<TypeVariable<?>, Type> currentEnv = environment;
        while (currentClass != null) {
            for (Type genericInterface : currentClass.getGenericInterfaces()) {
                resultList.add(new Entry(genericInterface, currentEnv));
            }
            // A parameterized superclass binds the superclass's own type variables; a
            // non-parameterized one (or none at all) contributes no bindings.
            Type genericSuperclass = currentClass.getGenericSuperclass();
            if (genericSuperclass instanceof ParameterizedType) {
                currentEnv = computeParameterizedEnvironment((ParameterizedType) genericSuperclass, currentEnv);
            } else {
                currentEnv = new HashMap<>();
            }
            currentClass = currentClass.getSuperclass();
        }
        return resultList;
    }

    /**
     * Computes the type arguments that {@code instanceClass} binds for {@code genericInterface}.
     *
     * @param instanceClass a non-interface, non-generic class whose hierarchy is searched
     * @param genericInterface the generic interface whose type arguments are wanted
     * @return the resolved type arguments in declaration order; an argument that is a
     *     parameterized type is returned as its raw class, and an unbound type variable as
     *     {@code null}
     * @throws IllegalArgumentException if {@code instanceClass} is an interface or is generic, if
     *     {@code genericInterface} is not an interface, if the hierarchy reaches
     *     {@code genericInterface} only as a raw type, or if it does not reach it at all
     */
    public static Type[] computeGenericTypes(Class<?> instanceClass, Class<?> genericInterface) {
        if (instanceClass.isInterface()) {
            throw new IllegalArgumentException("The instance class " + instanceClass + " must not be an interface");
        } else if (!genericInterface.isInterface()) {
            throw new IllegalArgumentException("The generic class " + genericInterface + " must be an interface");
        } else if (instanceClass.getTypeParameters().length > 0) {
            throw new IllegalArgumentException("The instance class " + instanceClass + " must not be generic");
        }
        // Depth-first search over the type hierarchy (LIFO via addLast/removeLast).
        Deque<Entry> workStack = new ArrayDeque<>(retrieveGenericInterfaces(instanceClass, Collections.emptyMap()));
        while (!workStack.isEmpty()) {
            Entry entry = workStack.removeLast();
            Type type = entry.type;
            Map<TypeVariable<?>, Type> env = entry.environment;
            if (type instanceof ParameterizedType) {
                ParameterizedType parameterizedType = (ParameterizedType) type;
                Type rawType = parameterizedType.getRawType();
                if (genericInterface.equals(rawType)) {
                    return Arrays.stream(parameterizedType.getActualTypeArguments())
                            .map(t -> resolveType(env, t))
                            .toArray(Type[]::new);
                }
                // Keep searching above this interface with its own type variables bound.
                workStack.addLast(new Entry(rawType, computeParameterizedEnvironment(parameterizedType, env)));
            } else if (type instanceof Class) {
                Class<?> typeClass = (Class<?>) type;
                if (genericInterface.equals(typeClass)) {
                    throw new IllegalArgumentException("The provided class " + instanceClass + " is a raw type of " + genericInterface);
                }
                workStack.addAll(retrieveGenericInterfaces(typeClass, env));
            }
        }
        throw new IllegalArgumentException("The instance class " + instanceClass + " does not implement " + genericInterface);
    }
}
import org.junit.Test;
import java.lang.reflect.Type;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import static org.junit.Assert.assertArrayEquals;
import static GenericResolution.computeGenericTypes;
/**
 * Unit tests for GenericResolution.
 *
 * <p>Positive cases check the type arguments resolved for the {@code MyJsonFunction} fixtures
 * (and for {@code Properties} against {@code Map}); negative cases check that invalid or
 * unrelated inputs are rejected with {@link IllegalArgumentException}.
 *
 * @author <a href="mailto:shams.imam+github@gmail.com">Shams Imam</a>
 */
public class GenericResolutionTest {
    @Test
    public void checkPropertiesWithMap() {
        final Type[] actualTypes = computeGenericTypes(Properties.class, Map.class);
        final Type[] expectedTypes = new Type[]{Object.class, Object.class};
        assertArrayEquals(expectedTypes, actualTypes);
    }

    // AbstractMap is a class, not an interface, so it is rejected as the target type.
    @Test(expected = IllegalArgumentException.class)
    public void checkPropertiesWithAbstractMap() {
        computeGenericTypes(Properties.class, AbstractMap.class);
    }

    // JsonFunc1Impl implements JsonFunction, not List.
    @Test(expected = IllegalArgumentException.class)
    public void checkJsonFunc1ImplWithList() {
        computeGenericTypes(MyJsonFunction.JsonFunc1Impl.class, List.class);
    }

    // Object does not implement JsonFunction.
    @Test(expected = IllegalArgumentException.class)
    public void checkObjectWithJsonFunction() {
        computeGenericTypes(Object.class, JsonFunction.class);
    }

    // Interfaces are rejected as the instance class.
    @Test(expected = IllegalArgumentException.class)
    public void checkListWithJsonFunction() {
        computeGenericTypes(List.class, JsonFunction.class);
    }

    // Generic instance classes are rejected.
    @Test(expected = IllegalArgumentException.class)
    public void checkArrayListWithJsonFunction() {
        computeGenericTypes(ArrayList.class, JsonFunction.class);
    }

    @Test(expected = IllegalArgumentException.class)
    public void checkMapWithJsonFunction() {
        computeGenericTypes(Map.class, JsonFunction.class);
    }

    @Test(expected = IllegalArgumentException.class)
    public void checkHashMapWithJsonFunction() {
        computeGenericTypes(HashMap.class, JsonFunction.class);
    }

    @Test(expected = IllegalArgumentException.class)
    public void checkJsonFunctionWithJsonFunction() {
        computeGenericTypes(JsonFunction.class, JsonFunction.class);
    }

    @Test(expected = IllegalArgumentException.class)
    public void checkJsonFunc3IntfWithJsonFunction() {
        computeGenericTypes(MyJsonFunction.JsonFunc3Intf.class, JsonFunction.class);
    }

    @Test(expected = IllegalArgumentException.class)
    public void checkJsonFunc4IntfWithJsonFunction() {
        computeGenericTypes(MyJsonFunction.JsonFunc4Intf.class, JsonFunction.class);
    }

    @Test(expected = IllegalArgumentException.class)
    public void checkJsonFunc5IntfWithJsonFunction() {
        computeGenericTypes(MyJsonFunction.JsonFunc5Intf.class, JsonFunction.class);
    }

    @Test(expected = IllegalArgumentException.class)
    public void checkJsonFunc6IntfWithJsonFunction() {
        computeGenericTypes(MyJsonFunction.JsonFunc6Intf.class, JsonFunction.class);
    }

    @Test(expected = IllegalArgumentException.class)
    public void checkJsonFunc7IntfWithJsonFunction() {
        computeGenericTypes(MyJsonFunction.JsonFunc7Intf.class, JsonFunction.class);
    }

    @Test
    public void checkJsonFunc1ImplWithJsonFunction() {
        final Type[] actualTypes = computeGenericTypes(MyJsonFunction.JsonFunc1Impl.class, JsonFunction.class);
        final Type[] expectedTypes = new Type[]{MyJsonFunction.JsonParam.class, MyJsonFunction.JsonReturn.class};
        assertArrayEquals(expectedTypes, actualTypes);
    }

    // JsonFunc2Impl implements the raw JsonFunction type, so there are no arguments to resolve.
    @Test(expected = IllegalArgumentException.class)
    public void checkJsonFunc2ImplWithJsonFunction() {
        computeGenericTypes(MyJsonFunction.JsonFunc2Impl.class, JsonFunction.class);
    }

    @Test
    public void checkJsonFunc3ImplWithJsonFunction() {
        final Type[] actualTypes = computeGenericTypes(MyJsonFunction.JsonFunc3Impl.class, JsonFunction.class);
        final Type[] expectedTypes = new Type[]{MyJsonFunction.JsonParam.class, MyJsonFunction.JsonReturn.class};
        assertArrayEquals(expectedTypes, actualTypes);
    }

    @Test
    public void checkJsonFunc4ImplWithJsonFunction() {
        final Type[] actualTypes = computeGenericTypes(MyJsonFunction.JsonFunc4Impl.class, JsonFunction.class);
        final Type[] expectedTypes = new Type[]{MyJsonFunction.JsonParam.class, MyJsonFunction.JsonParam.class};
        assertArrayEquals(expectedTypes, actualTypes);
    }

    @Test
    public void checkJsonFunc5ImplWithJsonFunction() {
        final Type[] actualTypes = computeGenericTypes(MyJsonFunction.JsonFunc5Impl.class, JsonFunction.class);
        final Type[] expectedTypes = new Type[]{MyJsonFunction.JsonParam.class, MyJsonFunction.JsonReturn.class};
        assertArrayEquals(expectedTypes, actualTypes);
    }

    @Test
    public void checkJsonFunc6ImplWithJsonFunction() {
        final Type[] actualTypes = computeGenericTypes(MyJsonFunction.JsonFunc6Impl.class, JsonFunction.class);
        final Type[] expectedTypes = new Type[]{MyJsonFunction.JsonParam.class, MyJsonFunction.JsonParam.class};
        assertArrayEquals(expectedTypes, actualTypes);
    }

    // Parameterized type arguments (Map<X, Y>, List<W>) are reported as their raw classes.
    @Test
    public void checkJsonFunc7ImplWithJsonFunction() {
        final Type[] actualTypes = computeGenericTypes(MyJsonFunction.JsonFunc7Impl.class, JsonFunction.class);
        final Type[] expectedTypes = new Type[]{Map.class, List.class};
        assertArrayEquals(expectedTypes, actualTypes);
    }

    @Test
    public void checkJsonFunc7BImplWithJsonFunction() {
        final Type[] actualTypes = computeGenericTypes(MyJsonFunction.JsonFunc7BImpl.class, JsonFunction.class);
        final Type[] expectedTypes = new Type[]{Map.class, List.class};
        assertArrayEquals(expectedTypes, actualTypes);
    }

    @Test
    public void checkJsonFunc8BImplWithJsonFunction() {
        final Type[] actualTypes = computeGenericTypes(MyJsonFunction.JsonFunc8BImpl.class, JsonFunction.class);
        final Type[] expectedTypes = new Type[]{Float.class, Double.class};
        assertArrayEquals(expectedTypes, actualTypes);
    }
}
import java.util.List;
import java.util.Map;
/**
 * Fixture types for the GenericResolution tests: implementations of {@code JsonFunction} that
 * bind its type arguments directly, through intermediate interfaces that rename, reorder or
 * duplicate type variables, and through abstract superclasses.
 *
 * @author <a href="mailto:shams.imam+github@gmail.com">Shams Imam</a>
 */
public class MyJsonFunction {

    /** Binds both type arguments directly. */
    public static class JsonFunc1Impl implements JsonFunction<JsonParam, JsonReturn> {
        @Override
        public JsonReturn apply(JsonParam param) {
            String value = param.getValue();
            return new JsonReturn(value);
        }
    }

    /** Implements the raw {@code JsonFunction} type, so no type arguments are bound. */
    public static class JsonFunc2Impl implements JsonFunction {
        @Override
        public Object apply(Object param) {
            String text = String.valueOf(param);
            return new JsonReturn(text);
        }
    }

    /** Non-generic interface that fixes both type arguments. */
    public interface JsonFunc3Intf extends JsonFunction<JsonParam, JsonReturn> {
    }

    /** Inherits its bindings from {@link JsonFunc3Intf}. */
    public static class JsonFunc3Impl implements JsonFunc3Intf {
        @Override
        public JsonReturn apply(JsonParam param) {
            String value = param.getValue();
            return new JsonReturn(value);
        }
    }

    /** Uses a single type variable for both type arguments. */
    public interface JsonFunc4Intf<T> extends JsonFunction<T, T> {
    }

    public static class JsonFunc4Impl implements JsonFunc4Intf<JsonParam> {
        @Override
        public JsonParam apply(JsonParam param) {
            return param;
        }
    }

    /** Declares its type variables in the reverse order of {@code JsonFunction}'s. */
    public interface JsonFunc5Intf<R, P> extends JsonFunction<P, R> {
    }

    public static class JsonFunc5Impl implements JsonFunc5Intf<JsonReturn, JsonParam> {
        @Override
        public JsonReturn apply(JsonParam param) {
            String value = param.getValue();
            return new JsonReturn(value);
        }
    }

    /** Two levels of indirection collapsing onto a single type variable. */
    public interface JsonFunc6Intf<T> extends JsonFunc5Intf<T, T> {
    }

    public static class JsonFunc6Impl implements JsonFunc6Intf<JsonParam> {
        @Override
        public JsonParam apply(JsonParam param) {
            return param;
        }
    }

    /** Type arguments that are themselves parameterized types. */
    public interface JsonFunc7Intf<W, X, Y> extends JsonFunc5Intf<List<W>, Map<X, Y>> {
    }

    public static class JsonFunc7Impl implements JsonFunc7Intf<Integer, String, Double> {
        @Override
        public List<Integer> apply(Map<String, Double> param) {
            return null;
        }
    }

    /** Generic abstract superclass that forwards (and nests) its type variables. */
    public static abstract class JsonFunc7AImpl<A, B, C, D> implements JsonFunc7Intf<A, D, Map<B, C>> {
    }

    public static class JsonFunc7BImpl extends JsonFunc7AImpl<Integer, Double, Float, Long> {
        @Override
        public List<Integer> apply(Map<Long, Map<Double, Float>> param) {
            return null;
        }
    }

    /** Declares a type variable ({@code X}) that never reaches {@code JsonFunction}. */
    public interface JsonFunc8AIntf<W, X, Y> extends JsonFunc5Intf<W, Y> {
    }

    public interface JsonFunc8BIntf<A, B> extends JsonFunc8AIntf<A, Integer, B> {
    }

    /** Non-generic abstract superclass that fixes the bindings for its subclasses. */
    public static abstract class JsonFunc8AImpl implements JsonFunc8BIntf<Double, Float> {
    }

    public static class JsonFunc8BImpl extends JsonFunc8AImpl {
        @Override
        public Double apply(Float param) {
            return null;
        }
    }

    /** Mutable bean wrapping a single string value; equality is by that value. */
    public static class JsonParam {
        String value;

        public JsonParam() {
        }

        public JsonParam(String value) {
            this.value = value;
        }

        public String getValue() {
            return value;
        }

        public void setValue(String value) {
            this.value = value;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof JsonParam)) {
                return false;
            }
            String mine = getValue();
            String theirs = ((JsonParam) o).getValue();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        public int hashCode() {
            String mine = getValue();
            return mine == null ? 0 : mine.hashCode();
        }

        @Override
        public String toString() {
            return "JsonParam{value='" + value + "'}";
        }
    }

    /** Mutable bean wrapping a single name; equality is by that name. */
    public static class JsonReturn {
        String name;

        public JsonReturn() {
        }

        public JsonReturn(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof JsonReturn)) {
                return false;
            }
            String mine = getName();
            String theirs = ((JsonReturn) o).getName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        public int hashCode() {
            String mine = getName();
            return mine == null ? 0 : mine.hashCode();
        }

        @Override
        public String toString() {
            return "JsonReturn{name='" + name + "'}";
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment