Skip to content

Instantly share code, notes, and snippets.

@samtgarson
Created March 19, 2024 08:57
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save samtgarson/f1a84096aa53786e9281bd36d4e2dc0a to your computer and use it in GitHub Desktop.
Save samtgarson/f1a84096aa53786e9281bd36d4e2dc0a to your computer and use it in GitHub Desktop.
Custom Safe Reference type in Mobx State Tree
import { types } from 'mobx-state-tree'
import { getModelType, getModelTypes } from './get-model-type'
// Fixture model named 'models': an identified node whose `name` must be the literal 'frodo'.
const TestModel = types.model('models', {
id: types.identifier,
name: types.literal('frodo'),
})
// Second fixture model named 'other-models', distinguishable from TestModel by
// its model name and by its `name` literal ('samwise').
const TestModel2 = types.model('other-models', {
id: types.identifier,
name: types.literal('samwise'),
})
// Specs for `getModelType`: it should unwrap composite types and hand back a
// single model type, or throw when no model type can be found.
describe('getModelType', () => {
  describe('when the model is a union type', () => {
    it('returns the model type', () => {
      const model = types.union(TestModel, TestModel2)
      // The first member of the union is the one returned.
      expect(getModelType(model)).toEqual(TestModel)
    })
  })

  describe('when the model is a nested union type', () => {
    it('returns the model type', () => {
      const union = types.union(TestModel, TestModel2)
      const union2 = types.union(TestModel, TestModel2)
      const model = types.union(union, union2)
      expect(getModelType(model)).toEqual(TestModel)
    })
  })

  describe('when the model is a model type', () => {
    it('returns the model type', () => {
      const model = TestModel
      expect(getModelType(model)).toEqual(TestModel)
    })
  })

  // Titles previously read "throws an error…" / "returns the model type",
  // which misreported what this case asserts.
  describe('when the type is not a model type', () => {
    it('throws an error', () => {
      const model = types.string
      expect(() => getModelType(model)).toThrowError('Unsupported model type')
    })
  })

  describe('when the model is a late type', () => {
    it('returns the model type', () => {
      const model = types.late(() => TestModel)
      expect(getModelType(model)).toEqual(TestModel)
    })
  })
})
// Specs for `getModelTypes`: it should unwrap composite types into the list of
// distinct model types they contain, or throw when there are none.
describe('getModelTypes', () => {
  describe('when the model is a union type', () => {
    it('returns the model types', () => {
      const model = types.union(TestModel, TestModel2)
      expect(getModelTypes(model)).toEqual([TestModel, TestModel2])
    })
  })

  describe('when the model is a nested union type', () => {
    it('returns the model types', () => {
      const union = types.union(TestModel, TestModel2)
      const union2 = types.union(TestModel, TestModel2)
      const model = types.union(union, union2)
      // Both inner unions contain the same two models; each is expected once.
      expect(getModelTypes(model)).toEqual([TestModel, TestModel2])
    })
  })

  describe('when the model is a model type', () => {
    it('returns the model type', () => {
      const model = TestModel
      expect(getModelTypes(model)).toEqual([TestModel])
    })
  })

  // Titles previously read "throws an error…" / "returns the model type",
  // which misreported what this case asserts.
  describe('when the type is not a model type', () => {
    it('throws an error', () => {
      const model = types.string
      expect(() => getModelTypes(model)).toThrowError('Unsupported model type')
    })
  })
})
import {
IAnyModelType,
IAnyType,
isArrayType,
isLateType,
isModelType,
isOptionalType,
isReferenceType,
isUnionType,
types,
} from 'mobx-state-tree'
/**
 * Extracts a single underlying model type from a (possibly composite) type.
 *
 * The type is unwrapped via `getModelTypes` and the first model type found is
 * returned, so for a union this is the first model member.
 *
 * @param model - The type to unwrap.
 * @returns The first model type contained in `model`.
 * @throws Error when `getModelTypes` finds no model type in `model`.
 **/
export const getModelType = (model: IAnyType): IAnyModelType => {
  const [firstModelType] = getModelTypes(model)
  return firstModelType
}
/**
 * Lists every model type contained in a (possibly composite) type.
 *
 * @param model - The type to unwrap.
 * @returns The model types produced by `decodeTypes(model)`, in the order found.
 * @throws Error (`Unsupported model type: <name>`) when no model type is found.
 */
export const getModelTypes = (model: IAnyType): IAnyModelType[] => {
  const modelTypes = decodeTypes(model).filter(isModelType)
  if (modelTypes.length === 0) {
    throw new Error(`Unsupported model type: ${model.name}`)
  }
  return modelTypes
}
/**
 * Recursively unwraps a type, or a list of types, into the leaf types it wraps.
 *
 * Optional, union, array, reference and late types are replaced by the result
 * of decoding the type(s) they wrap. The null type is dropped. Any other type
 * is returned as a single-element list. When an array is passed, the decoded
 * results of all entries are flattened and passed through `filter`.
 *
 * @param model - A single type or an array of types to unwrap.
 * @returns The leaf types found.
 */
export const decodeTypes = (model: IAnyType | IAnyType[]): IAnyType[] => {
  if (Array.isArray(model)) {
    return filter(model.flatMap((type) => decodeTypes(type)))
  }
  // The `as boolean` casts stop the type guards from narrowing `model`, so the
  // later branches still see it as `IAnyType`.
  if (isOptionalType(model) as boolean) {
    // @ts-expect-error `getSubTypes` is an untyped, internal method.
    return decodeTypes(model.getSubTypes())
  }
  if (isUnionType(model) as boolean) {
    // @ts-expect-error `_types` is an untyped, internal property.
    return decodeTypes(model._types)
  }
  if (isArrayType(model) as boolean) {
    // @ts-expect-error `_subType` is an untyped, internal property.
    return decodeTypes(model._subType)
  }
  if (isReferenceType(model) as boolean) {
    // @ts-expect-error `targetType` is an untyped, internal property.
    return decodeTypes(model.targetType)
  }
  if (isLateType(model) as boolean) {
    // @ts-expect-error `_definition` an untyped, internal method.
    return decodeTypes(model._definition())
  }
  // `types.null.is(x)` checks whether the *value* `x` is null, so it never
  // matched the null *type* itself. Compare by identity instead, and keep
  // treating a missing (nullish) entry as "nothing to decode".
  if (model == null || model === types.null) return []
  return [model]
}
/**
 * Cleans a list of types: drops the null type (and nullish entries) and
 * removes duplicates, keeping the first occurrence of each type.
 *
 * @param typeArray - The types to clean up.
 * @returns A new array without the null type and without duplicates.
 */
export const filter = (typeArray: IAnyType[]): IAnyType[] => {
  // `types.null.is(item)` validates a *value*, so it never matched the null
  // *type* itself; an identity comparison does.
  const withoutNull = typeArray.filter(
    (item) => item != null && item !== types.null
  )
  // A Set preserves insertion order, so the first occurrence of each type wins.
  return [...new Set(withoutNull)]
}
import { Instance, types, unprotect } from 'mobx-state-tree'
import { reference } from './reference'
// Specs for the `reference` helper imported from './reference'.
describe('saferReference', () => {
// Stores are re-created in `beforeEach`, so every test starts from empty maps.
let store: Instance<typeof StoreDefinition>
let otherStore: Instance<typeof OtherStoreDefinition>
let thirdStore: Instance<typeof ThirdStoreDefinition>
// Possible reference targets; both are keyed by `id`.
const OtherModel = types.model('OtherTestModule', {
id: types.identifier,
url: types.string,
})
const ThirdModel = types.model('ThirdTestModule', {
id: types.identifier,
path: types.string,
})
// Model under test, with one field per flavour of `reference`:
// default options, `required: true`, and a union of two target models.
const Model = types.model('TestModule', {
id: types.identifier,
optional: reference(OtherModel),
required: reference(OtherModel, { required: true }),
polymorphic: reference(types.union(OtherModel, ThirdModel)),
})
// Each store keeps its instances in a `data` map.
const StoreDefinition = types.model('ModelStore', {
data: types.map(Model),
})
const OtherStoreDefinition = types.model('OtherModelStore', {
data: types.map(OtherModel),
})
const ThirdStoreDefinition = types.model('ThirdModelStore', {
data: types.map(ThirdModel),
})
// The root holds the three stores; its property names are the strings used
// after `|` in the typed identifiers below (e.g. '1|otherTestModules').
const RootStoreDefinition = types.model('RootStore', {
testModules: types.optional(StoreDefinition, {}),
otherTestModules: types.optional(OtherStoreDefinition, {}),
thirdTestModules: types.optional(ThirdStoreDefinition, {}),
})
beforeEach(() => {
const root = RootStoreDefinition.create({})
store = root['testModules']
otherStore = root['otherTestModules']
thirdStore = root['thirdTestModules']
// Unprotect the tree so the tests can mutate the maps directly.
unprotect(root)
})
describe('with an ID-only reference [deprecated]', () => {
it('nullifies missing relationships', async () => {
store.data.set('1', { id: '1', optional: '1', required: '1' })
// The target does not exist yet, so the optional reference reads as undefined.
expect(store.data.get('1')?.optional).toBeUndefined()
const model = { id: '1', url: 'google' }
otherStore.data.set('1', model)
// Once the target has been added, the same reference resolves to it.
expect(store.data.get('1')?.optional).toEqual(model)
})
})
describe('with a reference containing a type', () => {
it('resolves the reference containing a type', async () => {
store.data.set('1', {
id: '1',
optional: '1|otherTestModules',
required: '1|otherTestModules',
})
const model = { id: '1', url: 'google' }
otherStore.data.set('1', model)
expect(store.data.get('1')?.optional).toEqual(model)
})
it('nullifies missing relationships', async () => {
store.data.set('1', {
id: '1',
optional: '1|otherTestModules',
required: '1|otherTestModules',
})
// Target missing: undefined first, resolved after it is added below.
expect(store.data.get('1')?.optional).toBeUndefined()
const model = { id: '1', url: 'google' }
otherStore.data.set('1', model)
expect(store.data.get('1')?.optional).toEqual(model)
})
it('handles required relationships', async () => {
store.data.set('1', {
id: '1',
optional: '1',
required: '1',
})
// No OtherModel with id '1' was added, so reading the required
// reference is expected to throw.
expect(() => {
store.data.get('1')?.required
}).toThrowError()
})
})
describe('with an ambiguous polymorphic reference', () => {
it('resolves the reference to the correct type', async () => {
store.data.set('1', {
id: '1',
optional: '1',
required: '1',
polymorphic: '1|thirdTestModules',
})
// Both target stores hold an instance with id '1'; the `|thirdTestModules`
// suffix is what should select the ThirdModel instance.
const model = { id: '1', url: 'google' }
otherStore.data.set('1', model)
const thirdModel = { id: '1', path: 'google' }
thirdStore.data.set('1', thirdModel)
expect(store.data.get('1')?.polymorphic).toEqual(thirdModel)
})
})
})
import {
getRoot,
IMaybeNull,
IReferenceType,
resolveIdentifier,
types,
ReferenceOptions,
resolvePath,
IAnyStateTreeNode,
IAnyModelType,
ITypeUnion,
} from 'mobx-state-tree'
import { getModelTypes } from './get-model-type'
/** Types that `reference` accepts as a target: a single model type or a union type. */
type ReferenceableType = IAnyModelType | ITypeUnion<any, any, any>
/**
 * Creates a reference type with a custom resolver that copes with polymorphic
 * targets and with targets that are missing from the tree.
 *
 * @remarks
 * Mobx state tree doesn't support relationships that are missing at
 * initialization, only ones that go missing afterwards. The resolver used here
 * looks the target up without throwing, and the property is typed as
 * potentially missing.
 *
 * Stored identifiers are either a bare `id` or `id|type`; the part after `|`
 * is handed to the lookup helpers as the type.
 *
 * Using this one function for every reference keeps the API uniform: pass
 * `required: true` when the target is always present, otherwise the result is
 * wrapped in `types.maybeNull` and should be treated as optional.
 *
 * @param model - The model type (or union) the reference points at.
 * @param options - `required` plus any reference options; options passed here
 *   are spread after the built-in `get`/`set` and therefore override them.
 * @returns The reference type, wrapped in `types.maybeNull` unless `required`.
 * @throws Error from the resolver when `required` is set and no target is found.
 **/
export function reference<Model extends ReferenceableType>(
  model: Model,
  options: { required: true } & Partial<ReferenceOptions<Model>>
): IReferenceType<Model>
export function reference<Model extends ReferenceableType>(
  model: Model,
  options?: { required?: false } & Partial<ReferenceOptions<Model>>
): IMaybeNull<IReferenceType<Model>>
export function reference<Model extends ReferenceableType>(
  model: Model,
  options: { required?: boolean } & Partial<ReferenceOptions<Model>> = {
    required: false,
  }
) {
  const { required, ...referenceOptions } = options
  // eslint-disable-next-line local-rules/disallow-types-reference
  const referenceType = types.reference(model, {
    get(identifier, parent) {
      const treeRoot = getRoot(parent)
      const [id, type] = `${identifier}`.split('|')
      // Decoded on every read rather than once up front.
      const candidates = getModelTypes(model)
      const resolved = resolveModels(id, type, treeRoot, candidates)
      if (!resolved && required) {
        const typeHint = type ? `of type \`${type}\`` : ''
        throw new Error(
          `Could not resolve model ${model.name} ${typeHint} with id ${id}`
        )
      }
      return resolved
    },
    set(value) {
      return value.id
    },
    ...referenceOptions,
  })
  return required ? referenceType : types.maybeNull(referenceType)
}
/**
 * Tries each candidate model type in order and returns the first instance
 * that resolves.
 *
 * @param id - Identifier of the target instance.
 * @param type - Type part of the stored identifier; may be empty.
 * @param root - Root node of the tree to search.
 * @param model - Candidate model types, tried in array order.
 * @returns The first truthy result of `resolveModel`, or `undefined`.
 */
const resolveModels = (
  id: string,
  type: string,
  root: IAnyStateTreeNode,
  model: IAnyModelType[]
) => {
  for (const candidate of model) {
    const match = resolveModel(id, type, root, candidate)
    if (!match) continue
    return match
  }
  return undefined
}
/**
 * Resolves one instance of a single model type.
 *
 * With a non-empty `type` the lookup is delegated to `resolveWithType`;
 * otherwise it falls back to `resolveIdentifier` (deprecated ID-only form).
 *
 * @param id - Identifier of the target instance.
 * @param type - Type part of the stored identifier; may be empty.
 * @param root - Root node of the tree to search.
 * @param model - Model type the target must belong to.
 * @returns Whatever the chosen lookup returns.
 */
const resolveModel = (
  id: string,
  type: string,
  root: IAnyStateTreeNode,
  model: IAnyModelType
) => {
  if (type) {
    return resolveWithType(id, type, root, model)
  }
  // deprecated: ID-only references carry no type to build a path from.
  return resolveIdentifier(model, root, id)
}
/**
 * Looks an instance up by path, at `/<type>/data/<id>` below the root, and
 * accepts it only if it passes `model.is`.
 *
 * @param id - Identifier of the target instance (last path segment).
 * @param type - First path segment below the root.
 * @param root - Root node the path is resolved against.
 * @param model - Model type the resolved node is checked against.
 * @returns The node at that path when it passes `model.is`; `undefined` when
 *   nothing is there, the check fails, or the lookup throws.
 */
const resolveWithType = (
  id: string,
  type: string,
  root: IAnyStateTreeNode,
  model: IAnyModelType
) => {
  try {
    const node = resolvePath(root, `/${type}/data/${id}`)
    return node && model.is(node) ? node : undefined
  } catch {
    // Best-effort lookup: any failure counts as "not found".
    return undefined
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment