Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
diff --git a/graph/src/data/graphql/validation.rs b/graph/src/data/graphql/validation.rs
index 19e677be..1a4953b5 100644
--- a/graph/src/data/graphql/validation.rs
+++ b/graph/src/data/graphql/validation.rs
@@ -146,8 +146,9 @@ fn validate_derived_from(schema: &Document) -> Result<(), SchemaValidationError>
let object_and_interface_type_fields = get_object_and_interface_type_fields(schema);
// Iterate over all derived fields in all entity types; include the
- // `field` argument of @derivedFrom directive
- for (object_type, field, target_field) in type_definitions
+ // interface types implemented by the entity that declares the
+ // `@derivedFrom` field, and the `field` argument of the `@derivedFrom` directive
+ for (object_type, interface_types, field, target_field) in type_definitions
.clone()
.iter()
.flat_map(|object_type| {
@@ -164,6 +165,15 @@ fn validate_derived_from(schema: &Document) -> Result<(), SchemaValidationError>
.map(|directive| {
(
object_type,
+ object_type
+ .implements_interfaces
+ .iter()
+ .filter(|iface| {
+ // FIXME: Filter interfaces by whether they declare the same field with
+ // a `@derivedFrom` directive
+ true
+ })
+ .collect::<Vec<_>>(),
field,
directive
.arguments
@@ -223,12 +233,34 @@ fn validate_derived_from(schema: &Document) -> Result<(), SchemaValidationError>
// For that, we will wind up comparing the `id`s of the two types
// when we query, and just assume that that's ok.
let target_field_type = get_base_type(&target_field.field_type);
- if target_field_type != &object_type.name && target_field_type != "ID" {
+ if target_field_type != &object_type.name
+ && target_field_type != "ID"
+ && !interface_types
+ .iter()
+ .any(|iface| &target_field_type == iface)
+ {
+ fn type_signatures(name: &String) -> Vec<String> {
+ vec![
+ format!("{}", name),
+ format!("{}!", name),
+ format!("[{}!]", name),
+ format!("[{}!]!", name),
+ ]
+ };
+
+ let mut valid_types = type_signatures(&object_type.name);
+ valid_types.extend(
+ interface_types
+ .iter()
+ .flat_map(|iface| type_signatures(iface)),
+ );
+ let valid_types = valid_types.join(", ");
+
let msg = format!(
- "field `{tf}` on type `{tt}` must have type `{ot}`, `{ot}!`, or `[{ot}!]!`",
+ "field `{tf}` on type `{tt}` must have one of the following types: {valid_types}",
tf = target_field.name,
tt = target_type_name,
- ot = object_type.name,
+ valid_types = valid_types,
);
return Err(invalid(object_type, &field.name, &msg));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment