@jedwards1211
Last active June 3, 2020 12:00
useConnection (a React hook for infinite scrolling with Apollo, Relay-style Connections, and react-virtualized)
/**
* @flow
* @prettier
*/
import { type QueryRenderProps } from 'react-apollo'
import * as React from 'react'
import {
get,
takeRightWhile,
takeWhile,
filter,
map,
fromPairs,
pick,
mapKeys,
last,
} from 'lodash/fp'
import updateIn from '@jcoreio/mutate/updateIn'
import * as graphql from 'graphql'
import { pipeline } from '../util/pipeline'
import useStash from '../react-hooks/useStash'
export type GridProps = {
onRowsRendered: Function,
rowCount: number,
getNode: (index: number) => ?Object,
}
type RenderedRange = {|
+startIndex: number,
+stopIndex: number,
+overscanStartIndex: number,
+overscanStopIndex: number,
|}
type Connection = {
+__typename: string,
+edges: $ReadOnlyArray<{
+__typename: string,
+cursor: string,
+node: ?Object,
}>,
+pageInfo: {
+__typename: string,
+startCursor: ?string,
+endCursor: ?string,
+hasPreviousPage: boolean,
+hasNextPage: boolean,
},
}
/**
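* Provides props to turn a react-virtualized List or Table into an infinite
* scroll backed by a [Relay Cursor Connection](https://facebook.github.io/relay/graphql/connections.htm)
* from GraphQL. (A Grid reports its rendered range via onSectionRendered
* rather than onRowsRendered, so it would need an adapter.)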
*
* You provide the GraphQL query, and this hook automatically figures out where
* the connection field is. The connection must be a top-level field of the
* query, there must be only one connection in the query, and the query must
* have at minimum the following (names within [ ] may vary):
*
* query (
* $[first]: Int
* $[after]: String
* $[last]: Int
* $[before]: String
* ) {
* [connection](
* first: $[first]
* after: $[after]
* last: $[last]
* before: $[before]
* ) {
* edges {
* cursor
* node {
* ...
* }
* }
* pageInfo {
* startCursor
* endCursor
* hasPreviousPage
* hasNextPage
* }
* }
* }
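*
* For example (a sketch; the `users` connection, its `name` field, and the
* UserList component are hypothetical; it assumes graphql-tag's gql,
* react-apollo's <Query>, and react-virtualized's <List>):
*
*   const USERS_QUERY = gql`
*     query Users($first: Int, $after: String, $last: Int, $before: String) {
*       users(first: $first, after: $after, last: $last, before: $before) {
*         edges { cursor node { id name } }
*         pageInfo { startCursor endCursor hasPreviousPage hasNextPage }
*       }
*     }
*   `
*
*   function UserList(props: QueryRenderProps<any, any>) {
*     const { rowCount, getNode, onRowsRendered } = useConnection(props, {
*       query: USERS_QUERY,
*       minFetchSize: 20,
*       maxNumEdgesToKeep: 500,
*     })
*     return (
*       <List
*         width={400}
*         height={600}
*         rowHeight={30}
*         rowCount={rowCount}
*         onRowsRendered={onRowsRendered}
*         rowRenderer={({ index, key, style }) => {
*           // getNode returns nothing for placeholder (loading) rows
*           const node = getNode(index)
*           return (
*             <div key={key} style={style}>
*               {node ? node.name : 'Loading...'}
*             </div>
*           )
*         }}
*       />
*     )
*   }
*
*   <Query query={USERS_QUERY} variables={{ first: 20 }}>
*     {props => <UserList {...props} />}
*   </Query>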
*
* Use the returned getNode function to look up the node at a given row index
* rather than indexing into the connection's edges directly: an offset is
* applied to keep rows in the same position when edges are evicted from the
* head of the connection after the user scrolls far down. getNode returns null
* or undefined for rows whose edges are not loaded yet.
*/
export default function useConnection(
queryRenderProps: QueryRenderProps<any, any>,
options: {
query: graphql.DocumentNode,
minFetchSize: number,
maxFetchSize?: number,
maxNumEdgesToKeep: number,
convertNode?: ?(Object) => Object,
}
): GridProps {
const { data, fetchMore, refetch } = queryRenderProps
const { query, minFetchSize, maxNumEdgesToKeep, convertNode } = options
const { connectionPath, variableNames } = React.useMemo(
() => analyzeQuery(query),
[query]
)
const maxFetchSize = Math.min(
maxNumEdgesToKeep,
options.maxFetchSize || minFetchSize
)
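// Clamp a requested fetch size between minFetchSize and maxFetchSize
// (minFetchSize takes precedence if the two conflict)
const clampFetchSize = fetchSize =>
Math.max(minFetchSize, Math.min(maxFetchSize, fetchSize))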
const [renderedRows, setRenderedRows] = React.useState<RenderedRange>({
overscanStartIndex: 0,
overscanStopIndex: 0,
startIndex: 0,
stopIndex: 0,
})
const [fetching, setFetching] = React.useState(false)
const stash: {
connection?: ?Connection,
numRowsBefore: number,
} = useStash()
const prevConnection = stash.connection
const connection = get(connectionPath)(data)
stash.connection = connection
// numRowsBefore is the row index at which the connection's first edge is displayed.
// E.g. if we evict the first 10 rows, we add 10 to numRowsBefore so that the
// rest of the rows still show in the same position, and we can reload
// those first 10 rows when the user scrolls back to them.
// If no rows at the head are evicted but there's a previous page, we use
// numRowsBefore = 1 so that the list will show a loading row at the top
// and it will trigger a fetch when visible.
if (connection) {
if (prevConnection) {
stash.numRowsBefore = connection.pageInfo.hasPreviousPage
? Math.max(
1,
stash.numRowsBefore + getStartOffset(prevConnection, connection)
)
: 0
} else {
stash.numRowsBefore = connection.pageInfo.hasPreviousPage ? 1 : 0
}
} else {
stash.numRowsBefore = 0
}
const { numRowsBefore } = stash
React.useEffect(() => {
if (fetching || !connection) return
const {
startIndex,
stopIndex,
overscanStartIndex,
overscanStopIndex,
} = renderedRows
const {
pageInfo: { hasPreviousPage, hasNextPage, startCursor, endCursor },
edges,
} = connection
const done = () => setFetching(false)
// if the user scrolls all the way back to the top and rows are missing
// there, just refetch the head of the list
if (startIndex === 0 && numRowsBefore > 0) {
setFetching(true)
// I had been putting first: maxFetchSize in the variables here, but
// unfortunately that puts the results under a different apollo cache
// key instead of replacing the data in the existing cache key.
refetch().then(done, done)
return
}
// first see if rows in the visible range need to be fetched
// otherwise, see if any rows in the overscan range need to be fetched
let fetchSide
if (startIndex < numRowsBefore && hasPreviousPage) {
fetchSide = 'before'
} else if (stopIndex >= numRowsBefore + edges.length && hasNextPage) {
fetchSide = 'after'
} else if (overscanStartIndex < numRowsBefore && hasPreviousPage) {
fetchSide = 'before'
} else if (
overscanStopIndex >= numRowsBefore + edges.length &&
hasNextPage
) {
fetchSide = 'after'
} else {
return
}
setFetching(true)
const fetchVariables =
fetchSide === 'before'
? {
last: clampFetchSize(numRowsBefore - overscanStartIndex),
before: startCursor,
first: null,
after: null,
}
: {
first: clampFetchSize(
overscanStopIndex + 1 - numRowsBefore - edges.length
),
after: endCursor,
last: null,
before: null,
}
fetchMore({
variables: mapKeys(v => variableNames[v])(fetchVariables),
updateQuery: (prev: Object, { fetchMoreResult, variables }) =>
updateIn(prev, connectionPath, (connection: ?Connection): Connection =>
mergeConnections({
connection,
more: get(connectionPath)(fetchMoreResult),
side: fetchSide,
renderedRows,
maxNumEdgesToKeep,
})
),
}).then(done, done)
}, [
queryRenderProps,
renderedRows,
numRowsBefore,
fetching,
minFetchSize,
maxFetchSize,
])
const rowCount = React.useMemo((): number => {
if (!connection) return 1
return (
numRowsBefore +
connection.edges.length +
(connection.pageInfo.hasNextPage ? 1 : 0)
)
}, [connection, numRowsBefore])
const getNode = React.useCallback(
(index: number): ?Object => {
const node = connection?.edges[index - numRowsBefore]?.node
return node && convertNode ? convertNode(node) : node
},
[connection, numRowsBefore, convertNode]
)
return {
rowCount,
getNode,
onRowsRendered: setRenderedRows,
}
}
function getField(
node: graphql.FieldNode | graphql.OperationDefinitionNode,
name: string
): ?graphql.FieldNode {
if (!node.selectionSet) return null
const { selections } = node.selectionSet
return (selections.find(
n => n.kind === 'Field' && n.name.value === name
): any)
}
function analyzeQuery(
query: graphql.DocumentNode
): {
connectionPath: Array<any>,
variableNames: {
first: string,
after: string,
last: string,
before: string,
},
} {
const queryOp: ?graphql.OperationDefinitionNode = (query.definitions.find(
d => d.kind === 'OperationDefinition' && d.operation === 'query'
): any)
if (!queryOp) throw new Error(`failed to find query OperationDefinition`)
const {
selectionSet: { selections: querySelections },
} = queryOp
const connectionField: ?graphql.FieldNode = (querySelections.find(
s => s.kind === 'Field' && getField(s, 'pageInfo')
): any)
if (!connectionField) {
throw new Error(`failed to find connection field`)
}
const edgesField = getField(connectionField, 'edges')
if (!edgesField) throw new Error(`failed to find edges field`)
for (const name of ['cursor', 'node']) {
if (!getField(edgesField, name))
throw new Error(`missing edges.${name} field`)
}
const pageInfoField = getField(connectionField, 'pageInfo')
if (!pageInfoField) throw new Error(`failed to find pageInfo field`)
for (const name of [
'startCursor',
'endCursor',
'hasNextPage',
'hasPreviousPage',
]) {
if (!getField(pageInfoField, name))
throw new Error(`missing pageInfo.${name} field`)
}
const { alias, name } = connectionField
const connectionVars = ['first', 'after', 'last', 'before']
const variableNames = pipeline(
connectionField.arguments,
filter(a => a.value.kind === 'Variable'),
map(a => [a.name.value, a.value.name.value]),
fromPairs,
pick(connectionVars)
)
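// e.g. for `users(first: $count, after: $cursor, last: $last, before: $before)`
// this yields { first: 'count', after: 'cursor', last: 'last', before: 'before' },
// which useConnection later uses (via mapKeys) to rename its fetchMore variables.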
for (const name of connectionVars) {
if (typeof variableNames[name] !== 'string') {
throw new Error(`failed to find variable for ${name}`)
}
}
const connectionPath = [(alias || name).value]
return { connectionPath, variableNames }
}
/**
* Determines how many edges were removed from the head of the connection
* (positive result) or added to it (negative result) between renders. This
* can happen when:
* - rows get evicted from the head because the user scrolls far down
* - evicted rows at the head get refetched when the user scrolls back up
* - external operations like subscriptions add/remove rows
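*
* For example, if prevConnection's edges have cursors [c, d, e] (startCursor c)
* and nextConnection's edges are [a, b, c, d, e], two edges were added at the
* head and this returns -2. If nextConnection's edges are [d, e] (startCursor
* d), one edge was evicted and this returns 1.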
*/
function getStartOffset(
prevConnection: Connection,
nextConnection: Connection
): number {
const numAddedBefore = nextConnection.edges.findIndex(
e => e.cursor === prevConnection.pageInfo.startCursor
)
if (numAddedBefore >= 0) return -numAddedBefore
else {
const numDeletedBefore = prevConnection.edges.findIndex(
e => e.cursor === nextConnection.pageInfo.startCursor
)
if (numDeletedBefore >= 0) return numDeletedBefore
}
return 0
}
function mergeConnections({
connection,
more,
side,
renderedRows,
maxNumEdgesToKeep,
}: {
connection: ?Connection,
more: Connection,
side: 'before' | 'after',
renderedRows: RenderedRange,
maxNumEdgesToKeep: number,
}): Connection {
const {
__typename,
pageInfo: { __typename: pageInfoTypename },
} = more
if (!connection) {
connection = {
__typename,
edges: [],
pageInfo: {
__typename: pageInfoTypename,
hasPreviousPage: true,
hasNextPage: true,
startCursor: null,
endCursor: null,
},
}
}
if (side === 'after') {
let edges = [
...takeWhile(e => e.cursor !== more.pageInfo.startCursor)(
connection.edges
),
...more.edges,
]
if (edges.length > maxNumEdgesToKeep) {
let start = edges.length - maxNumEdgesToKeep
let end = edges.length
if (start > renderedRows.startIndex) {
start = renderedRows.startIndex
end = Math.max(
renderedRows.stopIndex + 1,
end - start + renderedRows.startIndex
)
}
edges = edges.slice(start, end)
}
const { startCursor, hasPreviousPage } = connection.pageInfo
return {
...more,
edges,
pageInfo: {
...more.pageInfo,
startCursor: edges[0]?.cursor,
hasPreviousPage:
edges[0]?.cursor === startCursor ? hasPreviousPage : true,
},
}
} else {
let edges = [
...more.edges,
...takeRightWhile(e => e.cursor !== more.pageInfo.endCursor)(
connection.edges
),
]
if (edges.length > maxNumEdgesToKeep) {
let start = 0
let end = maxNumEdgesToKeep
if (end <= renderedRows.stopIndex) {
start = Math.min(
renderedRows.startIndex,
start + renderedRows.stopIndex - end + 1
)
end = renderedRows.stopIndex
}
edges = edges.slice(start, end)
}
const { endCursor, hasNextPage } = connection.pageInfo
return {
...more,
edges,
pageInfo: {
...more.pageInfo,
endCursor: last(edges)?.cursor,
hasNextPage: last(edges)?.cursor === endCursor ? hasNextPage : true,
},
}
}
}
type RenderedRange = {|
+startIndex: number,
+stopIndex: number,
+overscanStartIndex: number,
+overscanStopIndex: number,
|}
type Connection = {
+__typename: string,
+edges: $ReadOnlyArray<{
+__typename: string,
+cursor: string,
+node: ?Object,
}>,
+pageInfo: {
+__typename: string,
+startCursor: ?string,
+endCursor: ?string,
+hasPreviousPage: boolean,
+hasNextPage: boolean,
},
}
/**
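* Provides props to turn a react-virtualized List or Table into an infinite
* scroll backed by a [Relay Cursor Connection](https://facebook.github.io/relay/graphql/connections.htm)
* from GraphQL. (A Grid reports its rendered range via onSectionRendered
* rather than onRowsRendered, so it would need an adapter.)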
*
* You provide the GraphQL query, and this hook automatically figures out where
* the connection field is. The connection must be a top-level field of the
* query, there must be only one connection in the query, and the query must
* have at minimum the following (names within [ ] may vary):
*
* query (
* $[first]: Int
* $[after]: String
* $[last]: Int
* $[before]: String
* ) {
* [connection](
* first: $[first]
* after: $[after]
* last: $[last]
* before: $[before]
* ) {
* edges {
* cursor
* node {
* ...
* }
* }
* pageInfo {
* startCursor
* endCursor
* hasPreviousPage
* hasNextPage
* }
* }
* }
*
* Use the returned getNode function to look up the node at a given row index
* rather than indexing into the connection's edges directly: an offset is
* applied to keep rows in the same position when edges are evicted from the
* head of the connection after the user scrolls far down. getNode returns null
* or undefined for rows whose edges are not loaded yet.
*/
export default function useConnection(
queryRenderProps: QueryRenderProps<any, any>,
options: {
query: graphql.DocumentNode,
minFetchSize: number,
maxFetchSize?: number,
maxNumEdgesToKeep: number,
convertNode?: ?(Object) => Object,
}
): GridProps {
const { data, fetchMore, refetch } = queryRenderProps
const { query, minFetchSize, maxNumEdgesToKeep, convertNode } = options
const { connectionPath, variableNames } = React.useMemo(
() => analyzeQuery(query),
[query]
)
const maxFetchSize = Math.min(
maxNumEdgesToKeep,
options.maxFetchSize || minFetchSize
)
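// Clamp a requested fetch size between minFetchSize and maxFetchSize
// (minFetchSize takes precedence if the two conflict)
const clampFetchSize = fetchSize =>
Math.max(minFetchSize, Math.min(maxFetchSize, fetchSize))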
const [renderedRows, setRenderedRows] = React.useState<RenderedRange>({
overscanStartIndex: 0,
overscanStopIndex: 0,
startIndex: 0,
stopIndex: 0,
})
const [fetching, setFetching] = React.useState(false)
const stash: {
connection?: ?Connection,
numRowsBefore: number,
} = useStash()
const prevConnection = stash.connection
const connection = get(connectionPath)(data)
stash.connection = connection
// numRowsBefore is the row index at which the connection's first edge is displayed.
// E.g. if we evict the first 10 rows, we add 10 to numRowsBefore so that the
// rest of the rows still show in the same position, and we can reload
// those first 10 rows when the user scrolls back to them.
// If no rows at the head are evicted but there's a previous page, we use
// numRowsBefore = 1 so that the list will show a loading row at the top
// and it will trigger a fetch when visible.
if (connection) {
if (prevConnection) {
stash.numRowsBefore = connection.pageInfo.hasPreviousPage
? Math.max(
1,
stash.numRowsBefore + getStartOffset(prevConnection, connection)
)
: 0
} else {
stash.numRowsBefore = connection.pageInfo.hasPreviousPage ? 1 : 0
}
} else {
stash.numRowsBefore = 0
}
const { numRowsBefore } = stash
React.useEffect(() => {
if (fetching || !connection) return
const {
startIndex,
stopIndex,
overscanStartIndex,
overscanStopIndex,
} = renderedRows
const {
pageInfo: { hasPreviousPage, hasNextPage, startCursor, endCursor },
edges,
} = connection
const done = () => setFetching(false)
// if the user scrolls all the way back to the top and rows are missing
// there, just refetch the head of the list
if (startIndex === 0 && numRowsBefore > 0) {
setFetching(true)
refetch({
first: maxFetchSize,
after: null,
last: null,
before: null,
}).then(done, done)
return
}
// first see if rows in the visible range need to be fetched
// otherwise, see if any rows in the overscan range need to be fetched
let fetchSide
if (startIndex < numRowsBefore && hasPreviousPage) {
fetchSide = 'before'
} else if (stopIndex >= numRowsBefore + edges.length && hasNextPage) {
fetchSide = 'after'
} else if (overscanStartIndex < numRowsBefore && hasPreviousPage) {
fetchSide = 'before'
} else if (
overscanStopIndex >= numRowsBefore + edges.length &&
hasNextPage
) {
fetchSide = 'after'
} else {
return
}
setFetching(true)
const fetchVariables =
fetchSide === 'before'
? {
last: clampFetchSize(numRowsBefore - overscanStartIndex),
before: startCursor,
first: null,
after: null,
}
: {
first: clampFetchSize(
overscanStopIndex + 1 - numRowsBefore - edges.length
),
after: endCursor,
last: null,
before: null,
}
fetchMore({
variables: mapKeys(v => variableNames[v])(fetchVariables),
updateQuery: (prev: Object, { fetchMoreResult, variables }) =>
updateIn(prev, connectionPath, (connection: ?Connection): Connection =>
mergeConnections({
connection,
more: get(connectionPath)(fetchMoreResult),
side: fetchSide,
maxNumEdgesToKeep,
})
),
}).then(done, done)
}, [
queryRenderProps,
renderedRows,
numRowsBefore,
fetching,
minFetchSize,
maxFetchSize,
])
const rowCount = React.useMemo((): number => {
if (!connection) return 1
return (
numRowsBefore +
connection.edges.length +
(connection.pageInfo.hasNextPage ? 1 : 0)
)
}, [connection, numRowsBefore])
const getNode = React.useCallback(
(index: number): ?Object => {
const node = connection?.edges[index - numRowsBefore]?.node
return node && convertNode ? convertNode(node) : node
},
[connection, numRowsBefore, convertNode]
)
return {
rowCount,
getNode,
onRowsRendered: setRenderedRows,
}
}
function getField(
node: graphql.FieldNode | graphql.OperationDefinitionNode,
name: string
): ?graphql.FieldNode {
if (!node.selectionSet) return null
const { selections } = node.selectionSet
return (selections.find(
n => n.kind === 'Field' && n.name.value === name
): any)
}
function analyzeQuery(
query: graphql.DocumentNode
): {
connectionPath: Array<any>,
variableNames: {
first: string,
after: string,
last: string,
before: string,
},
} {
const queryOp: ?graphql.OperationDefinitionNode = (query.definitions.find(
d => d.kind === 'OperationDefinition' && d.operation === 'query'
): any)
if (!queryOp) throw new Error(`failed to find query OperationDefinition`)
const {
selectionSet: { selections: querySelections },
} = queryOp
const connectionField: ?graphql.FieldNode = (querySelections.find(
s => s.kind === 'Field' && getField(s, 'pageInfo')
): any)
if (!connectionField) {
throw new Error(`failed to find connection field`)
}
const edgesField = getField(connectionField, 'edges')
if (!edgesField) throw new Error(`failed to find edges field`)
for (const name of ['cursor', 'node']) {
if (!getField(edgesField, name))
throw new Error(`missing edges.${name} field`)
}
const pageInfoField = getField(connectionField, 'pageInfo')
if (!pageInfoField) throw new Error(`failed to find pageInfo field`)
for (const name of [
'startCursor',
'endCursor',
'hasNextPage',
'hasPreviousPage',
]) {
if (!getField(pageInfoField, name))
throw new Error(`missing pageInfo.${name} field`)
}
const { alias, name } = connectionField
const connectionVars = ['first', 'after', 'last', 'before']
const variableNames = pipeline(
connectionField.arguments,
filter(a => a.value.kind === 'Variable'),
map(a => [a.name.value, a.value.name.value]),
fromPairs,
pick(connectionVars)
)
for (const name of connectionVars) {
if (typeof variableNames[name] !== 'string') {
throw new Error(`failed to find variable for ${name}`)
}
}
const connectionPath = [(alias || name).value]
return { connectionPath, variableNames }
}
/**
* Determines how many edges were removed from the head of the connection
* (positive result) or added to it (negative result) between renders. This
* can happen when:
* - rows get evicted from the head because the user scrolls far down
* - evicted rows at the head get refetched when the user scrolls back up
* - external operations like subscriptions add/remove rows
*/
function getStartOffset(
prevConnection: Connection,
nextConnection: Connection
): number {
const numAddedBefore = nextConnection.edges.findIndex(
e => e.cursor === prevConnection.pageInfo.startCursor
)
if (numAddedBefore >= 0) return -numAddedBefore
else {
const numDeletedBefore = prevConnection.edges.findIndex(
e => e.cursor === nextConnection.pageInfo.startCursor
)
if (numDeletedBefore >= 0) return numDeletedBefore
}
return 0
}
function mergeConnections({
connection,
more,
side,
maxNumEdgesToKeep,
}: {
connection: ?Connection,
more: Connection,
side: 'before' | 'after',
maxNumEdgesToKeep: number,
}): Connection {
const {
__typename,
pageInfo: { __typename: pageInfoTypename },
} = more
if (!connection) {
connection = {
__typename,
edges: [],
pageInfo: {
__typename: pageInfoTypename,
hasPreviousPage: true,
hasNextPage: true,
startCursor: null,
endCursor: null,
},
}
}
if (side === 'after') {
const edges = [
...takeWhile(e => e.cursor !== more.pageInfo.startCursor)(
connection.edges
),
...more.edges,
]
if (edges.length > maxNumEdgesToKeep) {
edges.splice(0, edges.length - maxNumEdgesToKeep)
}
const { startCursor, hasPreviousPage } = connection.pageInfo
return {
...more,
edges,
pageInfo: {
...more.pageInfo,
startCursor: edges[0]?.cursor,
hasPreviousPage:
edges[0]?.cursor === startCursor ? hasPreviousPage : true,
},
}
} else {
const edges = [
...more.edges,
...takeRightWhile(e => e.cursor !== more.pageInfo.endCursor)(
connection.edges
),
].slice(0, maxNumEdgesToKeep)
const { endCursor, hasNextPage } = connection.pageInfo
return {
...more,
edges,
pageInfo: {
...more.pageInfo,
endCursor: last(edges)?.cursor,
hasNextPage: last(edges)?.cursor === endCursor ? hasNextPage : true,
},
}
}
}
/**
* @flow
* @prettier
*/
// @$FlowFixMe
import { useRef } from 'react'
/**
* Gives you a persistent object (the value held by a useRef()) onto which you
* can attach whatever persistent local state you need. It's handier than
* having to use .current all over the place.
* If you pass an argument, its props are assigned to the stash object, so
* you can use this to update values on each render if desired.
*/
export default function useStash<Props: Object>(
props: $Shape<Props> = {}
): Props {
const ref = useRef(props)
return Object.assign(ref.current, props)
}
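// Example (a sketch; Ticker and its `onTick` prop are hypothetical). Because
// the stash is the same object on every render and its props are reassigned
// each time, an effect that subscribes only once can still call the latest
// callback:
//
//   function Ticker({ onTick }) {
//     const stash = useStash({ onTick })
//     React.useEffect(() => {
//       const id = setInterval(() => stash.onTick(), 1000)
//       return () => clearInterval(id)
//     }, [])
//     return null
//   }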
@jedwards1211 (author):

I realized a potential bug I need to fix: if it fetches too many rows relative to the limit that it keeps in memory, it could probably end up evicting rows in the visible range, refetching them on the next render, and repeatedly thrashing like this.
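The thrashing in the comment above happens when head edges that are still on screen get evicted. The sketch below shows one possible guard. `trimHeadKeepingVisible` is a hypothetical helper, not code from the gist. It only covers trimming the head after an `'after'` fetch. It assumes edge `i` is displayed at row `numRowsBefore + i`, which is how getNode maps rows to edges. The range reported by onRowsRendered is in row indices, so it has to be converted to edge indices before it is compared with positions in the edges array.

```javascript
// Hypothetical helper (not part of the gist): trims edges from the head of a
// merged edge list so that at most maxNumEdgesToKeep remain, except that it
// never evicts an edge whose row is currently visible.
//   numRowsBefore: row index at which edges[0] is displayed
//   startIndex:    first visible row index, as reported by onRowsRendered
function trimHeadKeepingVisible(edges, { maxNumEdgesToKeep, numRowsBefore, startIndex }) {
  if (edges.length <= maxNumEdgesToKeep) return { edges, numEvicted: 0 }
  // Convert the first visible row index into an edge index. Rows before
  // numRowsBefore are placeholders, so clamp at 0.
  const firstVisibleEdge = Math.max(0, startIndex - numRowsBefore)
  // Evict only as many head edges as needed, and never past the first
  // visible edge; if that isn't enough, temporarily keep more than the limit.
  const numEvicted = Math.min(edges.length - maxNumEdgesToKeep, firstVisibleEdge)
  return { edges: edges.slice(numEvicted), numEvicted }
}

// Usage: 12 loaded edges, limit 8, first edge displayed at row 1.
const edges = Array.from({ length: 12 }, (_, i) => ({ cursor: `c${i}`, node: { id: i } }))

// Visible rows start at row 3 (edge 2): only 2 edges can go, so 10 remain.
const a = trimHeadKeepingVisible(edges, { maxNumEdgesToKeep: 8, numRowsBefore: 1, startIndex: 3 })
console.log(a.numEvicted, a.edges.length, a.edges[0].cursor) // 2 10 c2

// Visible rows start at row 10 (edge 9): all 4 excess edges can be evicted.
const b = trimHeadKeepingVisible(edges, { maxNumEdgesToKeep: 8, numRowsBefore: 1, startIndex: 10 })
console.log(b.numEvicted, b.edges.length, b.edges[0].cursor) // 4 8 c4
```

Suppose the merged connection's `pageInfo.startCursor` is set to the first kept edge's cursor, as mergeConnections does. In that case getStartOffset reports `numEvicted` as the head offset on the next render, and numRowsBefore grows by that amount, so the remaining rows keep their positions.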

@jedwards1211 (author):

Just updated with several bugfixes I made.
