Skip to content

Instantly share code, notes, and snippets.

@jaens
Last active July 3, 2024 22:47
Show Gist options
  • Save jaens/7e15ae1984bb338c86eb5e452dee3010 to your computer and use it in GitHub Desktop.
Save jaens/7e15ae1984bb338c86eb5e452dee3010 to your computer and use it in GitHub Desktop.
Zod deep strict and map utility
import { z, type ZodDiscriminatedUnionOption } from "zod";
const RESOLVING = Symbol("mapOnSchema/resolving");

/**
 * Rebuilds `schema`, replacing every node with the result of calling `fn` on it.
 *
 * Nodes are rebuilt bottom-up. A container's children are mapped first. The container is then
 * re-created around the mapped children, keeping the rest of its `_def` (description, error map,
 * ...). Only after that is it passed to `fn`. Schema types not handled below (primitives, literals,
 * enums, ...) are passed to `fn` as-is. The input schema is never mutated.
 *
 * Every distinct schema instance is mapped at most once per call, so a sub-schema reused in
 * several places maps to one shared result. The inner schema of a `z.lazy` is mapped on demand,
 * whenever zod resolves the lazy, through the same cache. A lazy that refers back to an ancestor
 * therefore resolves to that ancestor's mapped version.
 *
 * @param schema Root schema to transform.
 * @param fn Callback applied to each schema node (at most once per distinct instance); its return
 *   value replaces the node in the output.
 * @returns The result of `fn` for the rebuilt root.
 * @throws Error("Recursive schema access detected") if a node is requested again while it is still
 *   being mapped (e.g. `fn` forces a lazy schema that leads back to one of its own ancestors).
 */
export function mapOnSchema<T extends z.ZodTypeAny, TResult extends z.ZodTypeAny>(
  schema: T,
  fn: (schema: z.ZodTypeAny) => TResult,
): TResult;
export function mapOnSchema(schema: z.ZodTypeAny, fn: (schema: z.ZodTypeAny) => z.ZodTypeAny): z.ZodTypeAny {
  // One cache for the whole traversal, including the lazy getters created below. This keeps
  // shared sub-schemas shared and lets recursive (lazy) schemas map to recursive output.
  // It is a WeakMap because the lazy getters keep it alive, and a getter that builds a brand-new
  // schema on every call would otherwise make the cache grow without bound.
  const results = new WeakMap<z.ZodTypeAny, z.ZodTypeAny | typeof RESOLVING>();

  /** Maps `s` (children first, then `fn`), memoized per schema instance. */
  function mapElement(s: z.ZodTypeAny): z.ZodTypeAny {
    const cached = results.get(s);
    if (cached === RESOLVING) {
      throw new Error("Recursive schema access detected");
    } else if (cached !== undefined) {
      return cached;
    }
    results.set(s, RESOLVING);
    try {
      const result = fn(mapChildren(s));
      results.set(s, result);
      return result;
    } catch (error) {
      // Don't leave a stale RESOLVING marker behind. A later lazy access would otherwise
      // report "Recursive schema access detected" instead of the real failure.
      results.delete(s);
      throw error;
    }
  }

  /** Re-creates `s` with each direct child schema replaced by its mapped version. */
  function mapChildren(s: z.ZodTypeAny): z.ZodTypeAny {
    if (s instanceof z.ZodObject) {
      const newShape: Record<string, z.ZodTypeAny> = {};
      for (const [key, value] of Object.entries(s.shape)) {
        newShape[key] = mapElement(value as z.ZodTypeAny);
      }
      const catchall: z.ZodTypeAny = s._def.catchall;
      return new z.ZodObject({
        ...s._def,
        shape: () => newShape,
        // ZodNever is zod's "no catchall" marker (unknownKeys applies instead). Mapping it could
        // replace it with another schema and silently change how unknown keys are handled.
        catchall: catchall instanceof z.ZodNever ? catchall : mapElement(catchall),
      });
    } else if (s instanceof z.ZodArray) {
      return new z.ZodArray({
        ...s._def,
        type: mapElement(s._def.type),
      });
    } else if (s instanceof z.ZodMap) {
      return new z.ZodMap({
        ...s._def,
        keyType: mapElement(s._def.keyType),
        valueType: mapElement(s._def.valueType),
      });
    } else if (s instanceof z.ZodSet) {
      return new z.ZodSet({
        ...s._def,
        valueType: mapElement(s._def.valueType),
      });
    } else if (s instanceof z.ZodOptional) {
      return new z.ZodOptional({
        ...s._def,
        innerType: mapElement(s._def.innerType),
      });
    } else if (s instanceof z.ZodNullable) {
      return new z.ZodNullable({
        ...s._def,
        innerType: mapElement(s._def.innerType),
      });
    } else if (s instanceof z.ZodDefault) {
      return new z.ZodDefault({
        ...s._def,
        innerType: mapElement(s._def.innerType),
      });
    } else if (s instanceof z.ZodReadonly) {
      return new z.ZodReadonly({
        ...s._def,
        innerType: mapElement(s._def.innerType),
      });
    } else if (s instanceof z.ZodLazy) {
      return new z.ZodLazy({
        ...s._def,
        // NB: this closure keeps `fn` and the cache reachable for as long as the returned schema
        // is. That is what lets recursive schemas be mapped lazily, one level per resolution.
        getter: () => mapElement(s._def.getter()),
      });
    } else if (s instanceof z.ZodBranded) {
      return new z.ZodBranded({
        ...s._def,
        type: mapElement(s._def.type),
      });
    } else if (s instanceof z.ZodEffects) {
      return new z.ZodEffects({
        ...s._def,
        schema: mapElement(s._def.schema),
      });
    } else if (s instanceof z.ZodPipeline) {
      return new z.ZodPipeline({
        ...s._def,
        in: mapElement(s._def.in),
        out: mapElement(s._def.out),
      });
    } else if (s instanceof z.ZodFunction) {
      return new z.ZodFunction({
        ...s._def,
        // `args` is a ZodTuple schema (not an array), so it is mapped as a node of its own.
        args: mapElement(s._def.args),
        returns: mapElement(s._def.returns),
      });
    } else if (s instanceof z.ZodPromise) {
      return new z.ZodPromise({
        ...s._def,
        type: mapElement(s._def.type),
      });
    } else if (s instanceof z.ZodCatch) {
      return new z.ZodCatch({
        ...s._def,
        innerType: mapElement(s._def.innerType),
      });
    } else if (s instanceof z.ZodTuple) {
      return new z.ZodTuple({
        ...s._def,
        items: s._def.items.map((item: z.ZodTypeAny) => mapElement(item)),
        rest: s._def.rest && mapElement(s._def.rest),
      });
    } else if (s instanceof z.ZodDiscriminatedUnion) {
      // Map `options` directly. Several discriminator values may share one option, so rebuilding
      // `options` from `optionsMap.values()` would duplicate it. The cache makes the
      // `optionsMap` entries below the very same mapped instances.
      const options = s._def.options.map(
        (option: z.ZodTypeAny) => mapElement(option) as ZodDiscriminatedUnionOption<string>,
      );
      const optionsMap = new Map(
        [...s.optionsMap.entries()].map(([k, v]) => [
          k,
          mapElement(v) as ZodDiscriminatedUnionOption<string>,
        ]),
      );
      return new z.ZodDiscriminatedUnion({
        ...s._def,
        options,
        optionsMap,
      });
    } else if (s instanceof z.ZodUnion) {
      return new z.ZodUnion({
        ...s._def,
        options: s._def.options.map((option: z.ZodTypeAny) => mapElement(option)),
      });
    } else if (s instanceof z.ZodIntersection) {
      return new z.ZodIntersection({
        ...s._def,
        left: mapElement(s._def.left),
        right: mapElement(s._def.right),
      });
    } else if (s instanceof z.ZodRecord) {
      return new z.ZodRecord({
        ...s._def,
        keyType: mapElement(s._def.keyType),
        valueType: mapElement(s._def.valueType),
      });
    } else {
      // Leaf (or unsupported) schema type: nothing to descend into.
      return s;
    }
  }

  return mapElement(schema);
}
/**
 * Returns a copy of `schema` in which every nested object schema has been made strict via
 * `.strict()` (i.e. unknown keys fail validation). Objects explicitly marked `.passthrough()`
 * are kept as they are.
 */
export function deepStrict<T extends z.ZodTypeAny>(schema: T): T {
  const makeStrict = (node: z.ZodTypeAny): z.ZodTypeAny => {
    if (!(node instanceof z.ZodObject)) {
      return node;
    }
    if (node._def.unknownKeys === "passthrough") {
      // Respect an explicit opt-out.
      return node;
    }
    return node.strict();
  };
  return mapOnSchema(schema, makeStrict) as T;
}
/**
 * Returns a copy of `schema` in which `.strict()` has been applied to every nested object schema,
 * including objects marked `.passthrough()` (unlike `deepStrict`, which leaves those alone).
 */
export function deepStrictAll<T extends z.ZodTypeAny>(schema: T): T {
  const strictIfObject = (node: z.ZodTypeAny): z.ZodTypeAny => {
    if (node instanceof z.ZodObject) {
      return node.strict();
    }
    return node;
  };
  return mapOnSchema(schema, strictIfObject) as T;
}
@kernwig
Copy link

kernwig commented Apr 1, 2024

@jaens I did indeed need it for an application where I dynamically add a superRefine to some string properties that may have a default value, or to a string property on an object inside an array that defaults to empty.

Aside: Because my callback fn needs to call an API to fetch the list of valid string values, I had to implement an async version of mapOnSchema.

@kernwig
Copy link

kernwig commented Apr 29, 2024

Just discovered that this function is why my object properties lose their `.describe()` values when I use zod-to-openapi to generate an OpenAPI document.

I fixed it by adding `{ ...schema._def }` as the second parameter to all of the `z.Zod<thing>.create()` calls, thus preserving not only the description but also `errorMap`, `invalid_type_error`, and `required_error`.

@jaens
Copy link
Author

jaens commented Jul 3, 2024

I updated the code to account for all known flaws. The use of `.create()` was removed to make it more "extension-proof".

@jaens
Copy link
Author

jaens commented Jul 3, 2024

> Aside: Because my callback fn needs to call an API to fetch the list of valid string values, I had to implement an async version of mapOnSchema.

Technically, it might not be necessary: it can be done in three passes (since the mapping order is deterministic, unless lazy schemas are used):

  1. Use mapOnSchema with the identity function, but as a side effect, push all required requests (as promises) into a collection, e.g. an array.
  2. Await all the promises in the array, collecting the results into a new array of resolved values.
  3. Run mapOnSchema again, this time reading the resolved values from the array (with an incrementing index).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment