Skip to content

Instantly share code, notes, and snippets.

@jedwards1211
Last active June 3, 2020 12:00
Show Gist options
  • Save jedwards1211/684500774eb2ff76dbb21bb9d9500716 to your computer and use it in GitHub Desktop.
Save jedwards1211/684500774eb2ff76dbb21bb9d9500716 to your computer and use it in GitHub Desktop.
useConnection (a React hook for infinite scrolling with Apollo, Relay-style Connections, and react-virtualized)
/**
* @flow
* @prettier
*/
import { type QueryRenderProps } from 'react-apollo'
import * as React from 'react'
import {
get,
takeRightWhile,
takeWhile,
filter,
map,
fromPairs,
pick,
mapKeys,
last,
} from 'lodash/fp'
import updateIn from '@jcoreio/mutate/updateIn'
import * as graphql from 'graphql'
import { pipeline } from '../util/pipeline'
import useStash from '../react-hooks/useStash'
// NOTE(review): the `@flow` header comment and the import block that follow
// this type repeat the ones above verbatim. Duplicate import bindings are a
// SyntaxError in an ES module, so one copy should be removed (likely a
// copy/paste artifact).
//
// Props returned by useConnection. rowCount and onRowsRendered are meant for a
// react-virtualized List/Table/Grid; getNode maps a row index to its node,
// accounting for the numRowsBefore offset.
export type GridProps = {
onRowsRendered: Function,
rowCount: number,
getNode: (index: number) => ?Object,
}/**
* @flow
* @prettier
*/
import { type QueryRenderProps } from 'react-apollo'
import * as React from 'react'
import {
get,
takeRightWhile,
takeWhile,
filter,
map,
fromPairs,
pick,
mapKeys,
last,
} from 'lodash/fp'
import updateIn from '@jcoreio/mutate/updateIn'
import * as graphql from 'graphql'
import { pipeline } from '../util/pipeline'
import useStash from '../react-hooks/useStash'
// NOTE(review): GridProps and the header/imports above it appear twice in this
// file; Flow rejects duplicate type bindings. Likely a copy/paste artifact.
//
// Props returned by useConnection. rowCount and onRowsRendered are meant for a
// react-virtualized List/Table/Grid; getNode maps a row index to its node,
// accounting for the numRowsBefore offset.
export type GridProps = {
onRowsRendered: Function,
rowCount: number,
getNode: (index: number) => ?Object,
}
// The rendered range reported through onRowsRendered (stored by useConnection
// with setRenderedRows). Indices are ROW indices: row i is backed by
// connection.edges[i - numRowsBefore], not by edges[i].
type RenderedRange = {|
+startIndex: number,
+stopIndex: number,
+overscanStartIndex: number,
+overscanStopIndex: number,
|}
// The subset of a Relay cursor connection that this module reads and writes.
// analyzeQuery checks that the query selects these fields.
type Connection = {
+__typename: string,
+edges: $ReadOnlyArray<{
+__typename: string,
+cursor: string,
+node: ?Object,
}>,
+pageInfo: {
+__typename: string,
+startCursor: ?string,
+endCursor: ?string,
+hasPreviousPage: boolean,
+hasNextPage: boolean,
},
}
/**
* Provides props to turn a react-virtualized List/Table/Grid into an infinite
* scroll backed by a [Relay Cursor Connection](https://facebook.github.io/relay/graphql/connections.htm)
* from GraphQL.
*
* You provide the GraphQL query and this will automatically figure out where
* the connection field is. There must only be one connection in the query
* and the query should have at minimum (names within [ ] may vary):
*
* query (
* $[first]: Int
* $[after]: String
* $[last]: Int
* $[before]: String
* ) {
* [connection](
* first: $[first]
* after: $[after]
* last: $[last]
* before: $[before]
* ) {
* edges {
* cursor
* node {
* ...
* }
* }
* pageInfo {
* startCursor
* endCursor
* hasPreviousPage
* hasNextPage
* }
* }
* }
*
* You should use the returned getNode function to get the node at the given
* row, because an offset may be applied to keep rows in the same position
* if edges get evicted from the head of the connection because the user scrolls
* way down.
*/
export default function useConnection(
queryRenderProps: QueryRenderProps<any, any>,
options: {
query: graphql.DocumentNode,
minFetchSize: number,
maxFetchSize?: number,
maxNumEdgesToKeep: number,
convertNode?: ?(Object) => Object,
}
): GridProps {
const { data, fetchMore, refetch } = queryRenderProps
const { query, minFetchSize, maxNumEdgesToKeep, convertNode } = options
const { connectionPath, variableNames } = React.useMemo(
() => analyzeQuery(query),
[query]
)
const maxFetchSize = Math.min(
maxNumEdgesToKeep,
options.maxFetchSize || minFetchSize
)
const clampFetchSize = FetchSize =>
Math.max(minFetchSize, Math.min(maxFetchSize, FetchSize))
const [renderedRows, setRenderedRows] = React.useState<RenderedRange>({
overscanStartIndex: 0,
overscanStopIndex: 0,
startIndex: 0,
stopIndex: 0,
})
const [fetching, setFetching] = React.useState(false)
const stash: {
connection?: ?Connection,
numRowsBefore: number,
} = useStash()
const prevConnection = stash.connection
const connection = get(connectionPath)(data)
stash.connection = connection
// numRowsBefore is the offset of the first row in the connection.
// E.g. if we evict the first 10 rows, we add 10 to numRowsBefore so that the
// rest of the rows still show in the same position, and we can reload
// those first 10 rows when the user scrolls back to them.
// If no rows at the head are evicted but there's a previous page, we use
// numRowsBefore = 1 so that the list will show a loading row at the top
// and it will trigger a fetch when visible.
if (connection) {
if (prevConnection) {
stash.numRowsBefore = connection.pageInfo.hasPreviousPage
? Math.max(
1,
stash.numRowsBefore + getStartOffset(prevConnection, connection)
)
: 0
} else {
stash.numRowsBefore = connection.pageInfo.hasPreviousPage ? 1 : 0
}
} else {
stash.numRowsBefore = 0
}
const { numRowsBefore } = stash
React.useEffect(() => {
if (fetching || !connection) return
const {
startIndex,
stopIndex,
overscanStartIndex,
overscanStopIndex,
} = renderedRows
const {
pageInfo: { hasPreviousPage, hasNextPage, startCursor, endCursor },
edges,
} = connection
const done = () => setFetching(false)
// if the user scrolls all the way back to the top and rows are missing
// there, just refetch the head of the list
if (startIndex === 0 && numRowsBefore > 0) {
setFetching(true)
// I had been putting first: maxFetchSize in the variables here, but
// unfortunately that puts the results under a different apollo cache
// key instead of replacing the data in the existing cache key.
refetch().then(done, done)
return
}
// first see if rows in the visible range need to be fetched
// otherwise, see if any rows in the overscan range need to be fetched
let fetchSide
if (startIndex < numRowsBefore && hasPreviousPage) {
fetchSide = 'before'
} else if (stopIndex >= numRowsBefore + edges.length && hasNextPage) {
fetchSide = 'after'
} else if (overscanStartIndex < numRowsBefore && hasPreviousPage) {
fetchSide = 'before'
} else if (
overscanStopIndex >= numRowsBefore + edges.length &&
hasNextPage
) {
fetchSide = 'after'
} else {
return
}
setFetching(true)
const fetchVariables =
fetchSide === 'before'
? {
last: clampFetchSize(numRowsBefore - overscanStartIndex),
before: startCursor,
first: null,
after: null,
}
: {
first: clampFetchSize(
overscanStopIndex + 1 - numRowsBefore - edges.length
),
after: endCursor,
last: null,
before: null,
}
fetchMore({
variables: mapKeys(v => variableNames[v])(fetchVariables),
updateQuery: (prev: Object, { fetchMoreResult, variables }) =>
updateIn(prev, connectionPath, (connection: ?Connection): Connection =>
mergeConnections({
connection,
more: get(connectionPath)(fetchMoreResult),
side: fetchSide,
renderedRows,
maxNumEdgesToKeep,
})
),
}).then(done, done)
}, [
queryRenderProps,
renderedRows,
numRowsBefore,
fetching,
minFetchSize,
maxFetchSize,
])
const rowCount = React.useMemo((): number => {
if (!connection) return 1
return (
numRowsBefore +
connection.edges.length +
(connection.pageInfo.hasNextPage ? 1 : 0)
)
}, [connection, numRowsBefore])
const getNode = React.useCallback(
(index: number): ?Object => {
const node = connection?.edges[index - numRowsBefore]?.node
return node && convertNode ? convertNode(node) : node
},
[connection, numRowsBefore, convertNode]
)
return {
rowCount,
getNode,
onRowsRendered: setRenderedRows,
}
}
function getField(
node: graphql.FieldNode | graphql.OperationDefinitionNode,
name: string
): ?graphql.FieldNode {
if (!node.selectionSet) return null
const { selections } = node.selectionSet
return (selections.find(
n => n.kind === 'Field' && n.name.value === name
): any)
}
function analyzeQuery(
query: graphql.DocumentNode
): {
connectionPath: Array<any>,
variableNames: {
first: string,
after: string,
last: string,
before: string,
},
} {
const queryOp: ?graphql.OperationDefinitionNode = (query.definitions.find(
d => d.kind === 'OperationDefinition' && d.operation === 'query'
): any)
if (!queryOp) throw new Error(`failed to find query OperationDefinition`)
const {
selectionSet: { selections: querySelections },
} = queryOp
const connectionField: ?graphql.FieldNode = (querySelections.find(
s => s.kind === 'Field' && getField(s, 'pageInfo')
): any)
if (!connectionField) {
throw new Error(`failed to find connection field`)
}
const edgesField = getField(connectionField, 'edges')
if (!edgesField) throw new Error(`failed to find edges field`)
for (const name of ['cursor', 'node']) {
if (!getField(edgesField, name))
throw new Error(`missing edges.${name} field`)
}
const pageInfoField = getField(connectionField, 'pageInfo')
if (!pageInfoField) throw new Error(`failed to find pageInfo field`)
for (const name of [
'startCursor',
'endCursor',
'hasNextPage',
'hasPreviousPage',
]) {
if (!getField(pageInfoField, name))
throw new Error(`missing pageInfo.${name} field`)
}
const { alias, name } = connectionField
const connectionVars = ['first', 'after', 'last', 'before']
const variableNames = pipeline(
connectionField.arguments,
filter(a => a.value.kind === 'Variable'),
map(a => [a.name.value, a.value.name.value]),
fromPairs,
pick(connectionVars)
)
for (const name of connectionVars) {
if (typeof variableNames[name] !== 'string') {
throw new Error(`failed to find variable for ${name}`)
}
}
const connectionPath = [(alias || name).value]
return { connectionPath, variableNames }
}
/**
* This determines how many rows have been added or removed from the head of
* the connection. This can happen when:
* - rows get evicted from the head because the user scrolls far down
* - evicted rows at the head get refetched when the user scrolls back up
* - external operations like subscriptions add/remove rows
*/
function getStartOffset(
prevConnection: Connection,
nextConnection: Connection
): number {
const numAddedBefore = nextConnection.edges.findIndex(
e => e.cursor === prevConnection.pageInfo.startCursor
)
if (numAddedBefore >= 0) return -numAddedBefore
else {
const numDeletedBefore = prevConnection.edges.findIndex(
e => e.cursor === nextConnection.pageInfo.startCursor
)
if (numDeletedBefore >= 0) return numDeletedBefore
}
return 0
}
function mergeConnections({
connection,
more,
side,
renderedRows,
maxNumEdgesToKeep,
}: {
connection: ?Connection,
more: Connection,
side: 'before' | 'after',
renderedRows: RenderedRange,
maxNumEdgesToKeep: number,
}): Connection {
const {
__typename,
pageInfo: { __typename: pageInfoTypename },
} = more
if (!connection) {
connection = {
__typename,
edges: [],
pageInfo: {
__typename: pageInfoTypename,
hasPreviousPage: true,
hasNextPage: true,
startCursor: null,
endCursor: null,
},
}
}
if (side === 'after') {
let edges = [
...takeWhile(e => e.cursor !== more.pageInfo.startCursor)(
connection.edges
),
...more.edges,
]
if (edges.length > maxNumEdgesToKeep) {
let start = edges.length - maxNumEdgesToKeep
let end = edges.length
if (start > renderedRows.startIndex) {
start = renderedRows.startIndex
end = Math.max(
renderedRows.stopIndex + 1,
end - start + renderedRows.startIndex
)
}
edges = edges.slice(start, end)
}
const { startCursor, hasPreviousPage } = connection.pageInfo
return {
...more,
edges,
pageInfo: {
...more.pageInfo,
startCursor: edges[0]?.cursor,
hasPreviousPage:
edges[0]?.cursor === startCursor ? hasPreviousPage : true,
},
}
} else {
let edges = [
...more.edges,
...takeRightWhile(e => e.cursor !== more.pageInfo.endCursor)(
connection.edges
),
]
if (edges.length > maxNumEdgesToKeep) {
let start = 0
let end = maxNumEdgesToKeep
if (end <= renderedRows.stopIndex) {
start = Math.min(
renderedRows.startIndex,
start + renderedRows.stopIndex - end + 1
)
end = renderedRows.stopIndex
}
edges = edges.slice(start, end)
}
const { endCursor, hasNextPage } = connection.pageInfo
return {
...more,
edges,
pageInfo: {
...more.pageInfo,
endCursor: last(edges)?.cursor,
hasNextPage: last(edges)?.cursor === endCursor ? hasNextPage : true,
},
}
}
}
// NOTE(review): these type declarations, and the functions that follow them,
// duplicate earlier declarations in this file (apparently an older revision
// pasted after the newer one). Flow and ES modules reject duplicate bindings,
// so one copy should be removed.
//
// The rendered range reported through onRowsRendered (stored by useConnection
// with setRenderedRows). Indices are ROW indices: row i is backed by
// connection.edges[i - numRowsBefore].
type RenderedRange = {|
+startIndex: number,
+stopIndex: number,
+overscanStartIndex: number,
+overscanStopIndex: number,
|}
// The subset of a Relay cursor connection that this module reads and writes.
// analyzeQuery checks that the query selects these fields.
type Connection = {
+__typename: string,
+edges: $ReadOnlyArray<{
+__typename: string,
+cursor: string,
+node: ?Object,
}>,
+pageInfo: {
+__typename: string,
+startCursor: ?string,
+endCursor: ?string,
+hasPreviousPage: boolean,
+hasNextPage: boolean,
},
}
/**
* Provides props to turn a react-virtualized List/Table/Grid into an infinite
* scroll backed by a [Relay Cursor Connection](https://facebook.github.io/relay/graphql/connections.htm)
* from GraphQL.
*
* You provide the GraphQL query and this will automatically figure out where
* the connection field is. There must only be one connection in the query
* and the query should have at minimum (names within [ ] may vary):
*
* query (
* $[first]: Int
* $[after]: String
* $[last]: Int
* $[before]: String
* ) {
* [connection](
* first: $[first]
* after: $[after]
* last: $[last]
* before: $[before]
* ) {
* edges {
* cursor
* node {
* ...
* }
* }
* pageInfo {
* startCursor
* endCursor
* hasPreviousPage
* hasNextPage
* }
* }
* }
*
* You should use the returned getNode function to get the node at the given
* row, because an offset may be applied to keep rows in the same position
* if edges get evicted from the head of the connection because the user scrolls
* way down.
*/
export default function useConnection(
queryRenderProps: QueryRenderProps<any, any>,
options: {
query: graphql.DocumentNode,
minFetchSize: number,
maxFetchSize?: number,
maxNumEdgesToKeep: number,
convertNode?: ?(Object) => Object,
}
): GridProps {
const { data, fetchMore, refetch } = queryRenderProps
const { query, minFetchSize, maxNumEdgesToKeep, convertNode } = options
const { connectionPath, variableNames } = React.useMemo(
() => analyzeQuery(query),
[query]
)
const maxFetchSize = Math.min(
maxNumEdgesToKeep,
options.maxFetchSize || minFetchSize
)
const clampFetchSize = FetchSize =>
Math.max(minFetchSize, Math.min(maxFetchSize, FetchSize))
const [renderedRows, setRenderedRows] = React.useState<RenderedRange>({
overscanStartIndex: 0,
overscanStopIndex: 0,
startIndex: 0,
stopIndex: 0,
})
const [fetching, setFetching] = React.useState(false)
const stash: {
connection?: ?Connection,
numRowsBefore: number,
} = useStash()
const prevConnection = stash.connection
const connection = get(connectionPath)(data)
stash.connection = connection
// numRowsBefore is the offset of the first row in the connection.
// E.g. if we evict the first 10 rows, we add 10 to numRowsBefore so that the
// rest of the rows still show in the same position, and we can reload
// those first 10 rows when the user scrolls back to them.
// If no rows at the head are evicted but there's a previous page, we use
// numRowsBefore = 1 so that the list will show a loading row at the top
// and it will trigger a fetch when visible.
if (connection) {
if (prevConnection) {
stash.numRowsBefore = connection.pageInfo.hasPreviousPage
? Math.max(
1,
stash.numRowsBefore + getStartOffset(prevConnection, connection)
)
: 0
} else {
stash.numRowsBefore = connection.pageInfo.hasPreviousPage ? 1 : 0
}
} else {
stash.numRowsBefore = 0
}
const { numRowsBefore } = stash
React.useEffect(() => {
if (fetching || !connection) return
const {
startIndex,
stopIndex,
overscanStartIndex,
overscanStopIndex,
} = renderedRows
const {
pageInfo: { hasPreviousPage, hasNextPage, startCursor, endCursor },
edges,
} = connection
const done = () => setFetching(false)
// if the user scrolls all the way back to the top and rows are missing
// there, just refetch the head of the list
if (startIndex === 0 && numRowsBefore > 0) {
setFetching(true)
refetch({
first: maxFetchSize,
after: null,
last: null,
before: null,
}).then(done, done)
return
}
// first see if rows in the visible range need to be fetched
// otherwise, see if any rows in the overscan range need to be fetched
let fetchSide
if (startIndex < numRowsBefore && hasPreviousPage) {
fetchSide = 'before'
} else if (stopIndex >= numRowsBefore + edges.length && hasNextPage) {
fetchSide = 'after'
} else if (overscanStartIndex < numRowsBefore && hasPreviousPage) {
fetchSide = 'before'
} else if (
overscanStopIndex >= numRowsBefore + edges.length &&
hasNextPage
) {
fetchSide = 'after'
} else {
return
}
setFetching(true)
const fetchVariables =
fetchSide === 'before'
? {
last: clampFetchSize(numRowsBefore - overscanStartIndex),
before: startCursor,
first: null,
after: null,
}
: {
first: clampFetchSize(
overscanStopIndex + 1 - numRowsBefore - edges.length
),
after: endCursor,
last: null,
before: null,
}
fetchMore({
variables: mapKeys(v => variableNames[v])(fetchVariables),
updateQuery: (prev: Object, { fetchMoreResult, variables }) =>
updateIn(prev, connectionPath, (connection: ?Connection): Connection =>
mergeConnections({
connection,
more: get(connectionPath)(fetchMoreResult),
side: fetchSide,
maxNumEdgesToKeep,
})
),
}).then(done, done)
}, [
queryRenderProps,
renderedRows,
numRowsBefore,
fetching,
minFetchSize,
maxFetchSize,
])
const rowCount = React.useMemo((): number => {
if (!connection) return 1
return (
numRowsBefore +
connection.edges.length +
(connection.pageInfo.hasNextPage ? 1 : 0)
)
}, [connection, numRowsBefore])
const getNode = React.useCallback(
(index: number): ?Object => {
const node = connection?.edges[index - numRowsBefore]?.node
return node && convertNode ? convertNode(node) : node
},
[connection, numRowsBefore, convertNode]
)
return {
rowCount,
getNode,
onRowsRendered: setRenderedRows,
}
}
function getField(
node: graphql.FieldNode | graphql.OperationDefinitionNode,
name: string
): ?graphql.FieldNode {
if (!node.selectionSet) return null
const { selections } = node.selectionSet
return (selections.find(
n => n.kind === 'Field' && n.name.value === name
): any)
}
function analyzeQuery(
query: graphql.DocumentNode
): {
connectionPath: Array<any>,
variableNames: {
first: string,
after: string,
last: string,
before: string,
},
} {
const queryOp: ?graphql.OperationDefinitionNode = (query.definitions.find(
d => d.kind === 'OperationDefinition' && d.operation === 'query'
): any)
if (!queryOp) throw new Error(`failed to find query OperationDefinition`)
const {
selectionSet: { selections: querySelections },
} = queryOp
const connectionField: ?graphql.FieldNode = (querySelections.find(
s => s.kind === 'Field' && getField(s, 'pageInfo')
): any)
if (!connectionField) {
throw new Error(`failed to find connection field`)
}
const edgesField = getField(connectionField, 'edges')
if (!edgesField) throw new Error(`failed to find edges field`)
for (const name of ['cursor', 'node']) {
if (!getField(edgesField, name))
throw new Error(`missing edges.${name} field`)
}
const pageInfoField = getField(connectionField, 'pageInfo')
if (!pageInfoField) throw new Error(`failed to find pageInfo field`)
for (const name of [
'startCursor',
'endCursor',
'hasNextPage',
'hasPreviousPage',
]) {
if (!getField(pageInfoField, name))
throw new Error(`missing pageInfo.${name} field`)
}
const { alias, name } = connectionField
const connectionVars = ['first', 'after', 'last', 'before']
const variableNames = pipeline(
connectionField.arguments,
filter(a => a.value.kind === 'Variable'),
map(a => [a.name.value, a.value.name.value]),
fromPairs,
pick(connectionVars)
)
for (const name of connectionVars) {
if (typeof variableNames[name] !== 'string') {
throw new Error(`failed to find variable for ${name}`)
}
}
const connectionPath = [(alias || name).value]
return { connectionPath, variableNames }
}
/**
* This determines how many rows have been added or removed from the head of
* the connection. This can happen when:
* - rows get evicted from the head because the user scrolls far down
* - evicted rows at the head get refetched when the user scrolls back up
* - external operations like subscriptions add/remove rows
*/
function getStartOffset(
prevConnection: Connection,
nextConnection: Connection
): number {
const numAddedBefore = nextConnection.edges.findIndex(
e => e.cursor === prevConnection.pageInfo.startCursor
)
if (numAddedBefore >= 0) return -numAddedBefore
else {
const numDeletedBefore = prevConnection.edges.findIndex(
e => e.cursor === nextConnection.pageInfo.startCursor
)
if (numDeletedBefore >= 0) return numDeletedBefore
}
return 0
}
function mergeConnections({
connection,
more,
side,
maxNumEdgesToKeep,
}: {
connection: ?Connection,
more: Connection,
side: 'before' | 'after',
maxNumEdgesToKeep: number,
}): Connection {
const {
__typename,
pageInfo: { __typename: pageInfoTypename },
} = more
if (!connection) {
connection = {
__typename,
edges: [],
pageInfo: {
__typename: pageInfoTypename,
hasPreviousPage: true,
hasNextPage: true,
startCursor: null,
endCursor: null,
},
}
}
if (side === 'after') {
const edges = [
...takeWhile(e => e.cursor !== more.pageInfo.startCursor)(
connection.edges
),
...more.edges,
]
if (edges.length > maxNumEdgesToKeep) {
edges.splice(0, edges.length - maxNumEdgesToKeep)
}
const { startCursor, hasPreviousPage } = connection.pageInfo
return {
...more,
edges,
pageInfo: {
...more.pageInfo,
startCursor: edges[0]?.cursor,
hasPreviousPage:
edges[0]?.cursor === startCursor ? hasPreviousPage : true,
},
}
} else {
const edges = [
...more.edges,
...takeRightWhile(e => e.cursor !== more.pageInfo.endCursor)(
connection.edges
),
].slice(0, maxNumEdgesToKeep)
const { endCursor, hasNextPage } = connection.pageInfo
return {
...more,
edges,
pageInfo: {
...more.pageInfo,
endCursor: last(edges)?.cursor,
hasNextPage: last(edges)?.cursor === endCursor ? hasNextPage : true,
},
}
}
}
/**
* @flow
* @prettier
*/
// @$FlowFixMe
import { useRef } from 'react'
/**
* Gives you a persistent object from useRef() that you can tack whatever
* persistent local state onto that you need. It's handier than having to
* use .current all over the place.
* If you pass an argument, its props are assigned to the stash object, so
* you can use this to update values on each render if desired.
*/
export default function useStash<Props: Object>(
props: $Shape<Props> = {}
): Props {
const ref = useRef(props)
return Object.assign(ref.current, props)
}
@jedwards1211
Copy link
Author

Just updated with several bugfixes I made.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment