@simon-mo
Last active October 31, 2023 08:25
Speculative decoding is a pivotal technique to accelerate the inference of large language models (LLMs) by employing a smaller draft model to predict the target model’s outputs. However, its efficacy can be limited by the low predictive accuracy of the draft model, particularly when faced with diverse text inputs and a significant capability gap between the draft and target models. We introduce online speculative decoding to address this challenge. The main idea is to continually update (multiple) draft model(s) on observed user query data using the abundant excess computational power in an LLM serving cluster. Given that LLM inference is memory-bound, the surplus computational power in a typical LLM serving cluster can be repurposed for online retraining of draft models, thereby making the training cost-neutral. Since the query distribution of an LLM service is relatively simple, retraining on the query distribution enables the draft model to more accurately predict the target model’s outputs, particularly on data originating from that distribution. As the draft model evolves online, it aligns with the query distribution in real time, mitigating distribution shifts. We develop a prototype of online speculative decoding based on online knowledge distillation and evaluate it using both synthetic and real query data on several popular LLMs. The results show a substantial increase in the token acceptance rate by 0.1 to 0.65, which translates into 1.22$\times$ to 3.06$\times$ latency reduction. Code is available at `https://github.com/LiuXiaoxuanPKU/OSD`.
**Xiaoxuan Liu**$^{1}$, **Lanxiang Hu**$^{2}$, **Peter Bailis**$^{3}$, **Ion Stoica**$^{1}$, **Zhijie Deng**$^{4*}$, **Alvin Cheung**$^{1}$, **Hao Zhang**$^{2*}$

$^{1}$UC Berkeley, $^{2}$UCSD, $^{3}$Sisu Data, $^{4}$SJTU; $^{*}$Corresponding author

`{xiaoxuanliu, istoica, akcheung}@cs.berkeley.edu`, `{lah003, haozhang}@ucsd.edu`, `peter@sisudata.com`, `zhijied@sjtu.edu.cn`
iclr2024_conference.bib
Online Speculative Decoding

Introduction

Large language models (LLMs) such as GPT-4, Claude, and Llama are rapidly reinventing today’s applications. Many companies are racing to deploy LLMs in their vertical domains, such as search, chatbots, and virtual assistants. Since most of these applications demand low latency, optimizing LLM serving latency is of vital importance and can directly translate into better quality of service and cost reduction.

The latency of today’s LLM services is unfortunately very high. This is primarily because serving a user query requires multiple serial evaluations of the LLM, each generating only one token of the response. An emerging solution to reduce the latency is speculative decoding. Speculative decoding employs a smaller model to speculate multiple output tokens of the target (large) model and then lets the target LLM verify these speculations in parallel. If the verification of a token fails, the large model must recompute from that point. Therefore, the performance of speculative decoding primarily depends on the speculation accuracy of the small model. In the presence of diverse text inputs, the accuracy of existing speculative decoding methods is unfortunately not very high, due to the capability gap between the draft and target models. Employing a larger, more accurate draft model, however, defeats the purpose of speculative decoding, as it potentially increases latency.

To address this challenge, we introduce a novel method, online speculative decoding, specifically designed for online LLM services. The method leverages the abundant redundant compute, termed “spare FLOPs,” available in a typical LLM serving cluster to continuously retrain (multiple) small draft models through online learning on query data posted to the LLM service. Our approach is simple and offers several significant advantages. First, user queries to a specific LLM service often exhibit a common domain-specific distribution, reflecting shared usage patterns. While accurately speculating the larger model’s outputs on any diverse input is challenging, it is feasible to enhance the draft model’s prediction accuracy only for similar inputs posted to the service, characterized by the query distribution. This can be achieved by finetuning the draft model on the user query distribution, or by finetuning multiple draft models, each on a cluster of the query distribution, and selecting the appropriately specialized draft model to speculate based on the class of inputs it was trained on. In §5.2, we show that it is possible to train multiple draft models, each for a different language or topic. Second, the primary bottleneck for transformer-based LLM inference is the accelerator’s memory bandwidth, as generating each word requires loading the model weights from HBM to SRAM as well as reading the KV cache for all previous words. This leaves a substantial amount of compute unused in an LLM serving cluster, especially during non-spike traffic hours. We demonstrate that these spare FLOPs can be effectively repurposed for online retraining of draft models with negligible retraining cost (§4.2.2). Third, since tuning is performed online, the draft models continuously evolve over time based on the observed query data, which ensures high speculation accuracy even when faced with shifts in query distribution.
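The memory-bandwidth argument above can be made concrete with a back-of-the-envelope estimate. The sketch below assumes fp16 weights (2 bytes per parameter), roughly 2 FLOPs per parameter per generated token, and approximate A100-80GB figures (about 312 TFLOP/s dense fp16 and about 2 TB/s of HBM bandwidth). These numbers are illustrative assumptions, not measurements from this paper, and KV-cache reads are ignored (they would only lower utilization further).

```python
def decode_compute_utilization(n_params: float,
                               batch_size: int = 1,
                               bytes_per_param: int = 2,      # assumed fp16 weights
                               peak_flops: float = 312e12,    # assumed dense fp16 FLOP/s
                               mem_bw: float = 2.0e12) -> float:  # assumed HBM bytes/s
    """Fraction of peak FLOP/s used by one decode step that streams all weights once.

    Ignores KV-cache traffic and kernel overheads, so it is an optimistic upper bound.
    """
    step_time = n_params * bytes_per_param / mem_bw   # seconds to read the weights once
    flops = 2 * n_params * batch_size                 # ~2 FLOPs/param/token, one token per sequence
    return flops / (step_time * peak_flops)


for bs in (1, 8, 32):
    print(f"7B model, batch {bs:>2}: {decode_compute_utilization(7e9, bs):.1%} of peak")
```

Under these assumptions, a batch-1 decode step uses well under 1% of peak compute; the remaining headroom is what online retraining of the draft model can exploit.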

Based on these insights, we develop an online speculative decoding framework to improve the efficiency of online LLM serving. To align the draft model with the target model on newly observed user queries, we develop a new online learning algorithm based on Generalized Knowledge Distillation (GKD). The algorithm keeps track of the recent queries on which the draft model has speculated incorrectly, and forces the draft model to emulate the target model’s outputs on these queries. The algorithm performs GKD-based gradient updates opportunistically, only when spare FLOPs are available, hiding the overhead.

Online speculative decoding overview. For each prompt, the draft model suggests multiple tokens in a single step. The target model then verifies these tokens, accepting some and rejecting others. If the student proposes incorrect tokens, both the draft and target distributions are stored in a buffer. Once the buffer exceeds a specified threshold, the draft model is updated by calculating the loss between the draft and target distributions using various distance metrics.

In summary, this paper makes the following contributions:

  • We introduce online speculative decoding to reduce LLM serving latency by adapting (multiple) draft models on the fly using query data and knowledge distillation.

  • We explore various GKD methods for constructing draft models and identify the most effective variants, suggesting them as superior alternatives to existing finetuning methods in offline settings.

  • Our method demonstrates a significant improvement in token acceptance rate, by 0.1 to 0.65 (10 to 65 percentage points) on diverse datasets, translating to a theoretical 1.2-3.1$\times$ reduction in latency with negligible additional cost. It surpasses existing methods that construct static draft models using fine-tuning or distillation on offline datasets, and matches the hypothetical accuracy achieved if all query data were available a priori.

Related Work

LLMs have become pervasive in today’s AI applications, underscoring the importance of optimizing LLM inference. Numerous system optimizations have been developed to improve the throughput of LLM serving. This paper particularly concentrates on a significant strand of research, speculative decoding, aimed at reducing the latency of LLM inference.

Speculative decoding. Speculative decoding accelerates LLM decoding by employing a (small) draft model to predict the outputs of the larger target model, which are then verified by the target model. Typically, the draft model, while having fewer parameters, is pretrained using the same training data as the target model, resulting in a negligible inference cost but with compromised capability. If the draft model can correctly predict more than one token per verification step, the memory I/O for accessing the model weights and KV cache at inference is amortized across multiple output tokens, thereby reducing latency, especially since LLM inference is often constrained by GPU HBM bandwidth. The efficacy of speculative decoding largely hinges on the draft model’s ability to accurately predict the target model’s outputs. Existing work improves the speculation accuracy by using multiple collectively boosted or staged draft models, or by retraining the target model with auxiliary prediction heads as a draft model. These methods predominantly assume a static draft model post-deployment. In contrast, our work introduces a framework that actively adapts the draft model to the evolving user query distribution on the fly, irrespective of the draft model’s construction.

Distillation for auto-regressive models. Knowledge distillation (KD) is a framework to generate smaller models that emulate the performance of larger models. However, KD in its conventional form has been observed to be less effective for LLMs. Prior work extends KD to autoregressive LLMs by decoding from the student model and optimizing the reverse KL divergence between student and teacher. Further work introduces generalized knowledge distillation (GKD) to optimize a linear combination of the forward KL and reverse KL between teacher and student, using a blend of teacher- and student-sampled data. Drawing inspiration from both works, our paper applies KD to speculative decoding for LLMs. We empirically determine the most effective KD variant for maximizing the draft model’s accuracy, and extend it to dynamically generate draft models for online LLM services.

Background

We first briefly review speculative decoding, a critical technique that accelerates inference of a large target LLM $p(\cdot|{\bm{x}})$ with token proposals from a small draft model $q_{\bm{\theta}}(\cdot|{\bm{x}})$. ${\bm{x}}$ denotes the concatenation of the input prompt and already generated tokens. The two distributions are both auto-regressive. We emphasize the parameters ${\bm{\theta}}$ of the draft model because we usually need to tailor them according to the target LLM for more substantial acceleration.

Speculative decoding uses a (small) draft model to propose $k$ tokens ${\bm{y}} \triangleq \{y_i\}_{i=1}^k \sim q_{\bm{\theta}}(\cdot|{\bm{x}})$ and lets the target LLM estimate the $k+1$ probabilities $\{p(y|{\bm{x}}, {\bm{y}}_{<i})\}_{i=1}^{k+1}$ in parallel. With $i$ rising from $1$ to $k$, speculative decoding accepts the proposal $y_i$ if $u \leq p(y_i|{\bm{x}}, {\bm{y}}_{<i}) / q_{\bm{\theta}}(y_i|{\bm{x}}, {\bm{y}}_{<i})$, where $u \sim U[0,1]$; otherwise, it exits. Let $a$ denote the number of accepted tokens, which takes values in $\{0,\dots,k\}$. We can sample an additional token $y_{a+1}$ from the following distribution:

$$p'(y) = \begin{cases} p(y|{\bm{x}}, {\bm{y}}_{<a+1}) & \text{if } a = k, \\ \mathrm{norm}\big(\max(0,\, p(y|{\bm{x}}, {\bm{y}}_{<a+1}) - q_{\bm{\theta}}(y|{\bm{x}}, {\bm{y}}_{<a+1}))\big) & \text{otherwise,} \end{cases}$$

where $\mathrm{norm}(\cdot)$ makes the probabilities over the vocabulary sum to $1$.
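The verification rule above can be sketched in a few lines of plain Python. This is a minimal illustration, assuming the draft and target distributions are already available as probability lists over a small vocabulary (real systems operate on logit tensors on the GPU); the function name `speculative_step` is hypothetical.

```python
import random


def speculative_step(draft_probs, target_probs, proposals, rng=None):
    """One verification round of speculative sampling.

    draft_probs[i]:  q_theta(. | x, y_<i+1) as a list over the vocabulary, i = 0..k-1
    target_probs[i]: p(. | x, y_<i+1), i = 0..k (k+1 distributions, scored in parallel)
    proposals[i]:    token y_{i+1}, sampled from draft_probs[i]
    Returns the a accepted proposals followed by one extra token (a+1 tokens in total).
    """
    rng = rng or random.Random()
    vocab = range(len(target_probs[0]))
    out = []
    for i, y in enumerate(proposals):
        p_y, q_y = target_probs[i][y], draft_probs[i][y]
        if rng.random() <= p_y / q_y:          # accept with probability min(1, p/q)
            out.append(y)
            continue
        # Rejected: draw y_{a+1} from norm(max(0, p - q)) and stop.
        # random.choices normalizes the weights, which plays the role of norm(.).
        residual = [max(0.0, p - q) for p, q in zip(target_probs[i], draft_probs[i])]
        out.append(rng.choices(vocab, weights=residual)[0])
        return out
    # All k proposals accepted (a = k): sample the bonus token from p(. | x, y_<k+1).
    out.append(rng.choices(vocab, weights=target_probs[len(proposals)])[0])
    return out
```

When the draft matches the target exactly, every proposal is accepted and the round emits $k+1$ tokens; when the target assigns zero probability to a proposal, that proposal is always rejected.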

Prior work has shown that the resulting samples $\tilde{{\bm{y}}} \triangleq \{y_1, \dots, y_{a+1}\}$ strictly follow the distribution of the target LLM $p(\cdot|{\bm{x}})$. We concatenate $\tilde{{\bm{y}}}$ to ${\bm{x}}$ and repeat the above process until ⟨EOS⟩ is generated. Each run of the target LLM generates $a+1$ tokens with $a\geq0$, which ensures that at least one new token is generated even in the worst case. The generation process can be significantly accelerated if the draft LLM better approximates the target one, i.e., if $a$ is larger for each target LLM run.
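To see why the output follows $p$, consider the first emitted token. It equals $y$ either because $y$ was proposed and accepted, with probability $q_{\bm{\theta}}(y)\min(1, p(y)/q_{\bm{\theta}}(y)) = \min(p(y), q_{\bm{\theta}}(y))$, or because the proposal was rejected and $y$ was drawn from the normalized residual. The rejection probability $1 - \sum_y \min(p(y), q_{\bm{\theta}}(y))$ equals the residual's normalizing constant, so the second term is exactly $\max(0, p(y) - q_{\bm{\theta}}(y))$, and the two terms sum to $p(y)$. The sketch below checks this with exact rational arithmetic on a toy three-token vocabulary; the distributions are arbitrary illustrative values.

```python
from fractions import Fraction as F


def first_token_distribution(p, q):
    """Exact law of the first token emitted by one speculative-sampling round."""
    accept = [min(pi, qi) for pi, qi in zip(p, q)]     # q(y) * min(1, p(y)/q(y))
    reject = 1 - sum(accept)                            # P(first proposal rejected)
    residual = [max(0, pi - qi) for pi, qi in zip(p, q)]
    z = sum(residual)                                   # normalizer of norm(max(0, p - q))
    return [a + (reject * r / z if z else 0) for a, r in zip(accept, residual)]


p = [F(1, 2), F(1, 3), F(1, 6)]   # toy target distribution
q = [F(1, 6), F(1, 2), F(1, 3)]   # toy draft distribution
print(first_token_distribution(p, q) == p)   # True
```

Applying the same argument position by position gives the result for the whole accepted prefix.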

Expected acceptance rate & speedup. The acceptance rate, denoted as $\alpha$, measures how closely the draft model approximates the target model. It is defined as the expected probability that speculative decoding accepts a proposal token $y_i \sim q_{\bm{\theta}}(\cdot|{\bm{x}}, {\bm{y}}_{<i})$ given the prompt. This rate directly determines the expected length $\mathbb{E}(|\tilde{{\bm{y}}}|)$ of $\tilde{{\bm{y}}}$ for each target LLM run and the speedup brought by speculative decoding.

Assuming that the $k + 1$ parallel evaluations of the target LLM $p$ take roughly the same amount of time as generating a single token, let $c$ be the ratio between the time of a single run of $q_{\bm{\theta}}$ and that of $p$. The expected generation length of a single target LLM run and the expected speedup in total wall time due to speculative decoding are:

$$\mathbb{E}(|\tilde{{\bm{y}}}|) = \frac{1 - \alpha^{k+1}}{1-\alpha},\quad \mathbb{E}(\mathrm{speedup})=\frac{1-\alpha^{k+1}}{(1-\alpha)(kc+1)}.$$

We depict the speedup for varying values of $\alpha$ in Figure 2, which demonstrates the importance of $\alpha$ in determining the speedup.

Figure 2: Speculative decoding speedups for varying values of α. For smaller α values, speculative decoding may even degrade performance (indicated by a speedup < 1), particularly when the draft model is sizeable. Furthermore, the relationship between speedup and α is superlinear; doubling the acceptance rate can yield a speedup exceeding 2×.
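The two expressions above are easy to evaluate directly. The sketch below assumes $k = 5$ proposed tokens (the setting used in our experiments) and an illustrative draft/target cost ratio $c = 0.1$, which is an assumption rather than a measured value. It reproduces both observations in the Figure 2 caption: a low $\alpha$ yields a slowdown, and doubling $\alpha$ more than doubles the speedup.

```python
def expected_tokens_per_run(alpha: float, k: int) -> float:
    """E|y~| = (1 - alpha^(k+1)) / (1 - alpha), the tokens produced per target-model run."""
    if alpha == 1.0:
        return k + 1.0                     # limit of the geometric sum as alpha -> 1
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def expected_speedup(alpha: float, k: int, c: float) -> float:
    """E(speedup) = E|y~| / (k*c + 1): k draft runs plus one target run per round."""
    return expected_tokens_per_run(alpha, k) / (k * c + 1)


for alpha in (0.2, 0.4, 0.8):
    print(alpha, round(expected_speedup(alpha, k=5, c=0.1), 2))
```

Under these assumptions the loop prints speedups of about 0.83, 1.11, and 2.46: $\alpha = 0.2$ is a net slowdown, and raising $\alpha$ from 0.4 to 0.8 improves the speedup by roughly 2.2×.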

Observation. Interestingly, we can actually enhance $\alpha$ based on a key observation: the speculative decoding process inherently identifies the inaccuracies of the small draft LLM and offers correct solutions for these inaccuracies. This essentially means that we receive valuable insights on the areas and strategies to refine the draft model at no additional cost. Viewed through the lens of online learning, we can effortlessly accumulate a set of input-output pairs, denoted as $([{\bm{x}}, {\bm{y}}_{<a+1}], p(y|{\bm{x}}, {\bm{y}}_{<a+1}))$, that have yet to be assimilated by the draft LLM, paving the way for its subsequent optimization. Given the reduced size of the draft model (for instance, it may be over $20\times$ smaller than the target model), its tuning is not only efficient but also viable for real-time online adjustments. Prior work has primarily approached speculative decoding in an offline manner, meaning the draft model remains static during online deployment. We next develop online speculative decoding to bridge this gap.

Online Speculative Decoding

We propose the online speculative decoding approach to update the draft model dynamically for more effective suggestions. We frame the learning problem based on the aforementioned auxiliary information as online knowledge distillation, where the teacher and student models correspond to the target and draft LLMs in speculative decoding, respectively. We elaborate on the details below.

Knowledge Distillation for Speculative Decoding

Knowledge distillation is a general framework to align the predictive distribution of a small model (i.e., student model) with that of a larger one (i.e., teacher model). Prior research has utilized knowledge distillation to compress neural networks, resulting in decreased inference costs and memory requirements. We posit that knowledge distillation is highly effective for speculative decoding. In this approach, the draft model acts as the student and the target model serves as the teacher. During speculative decoding, we possess complete information on both the proposed and verified probabilities of each token. This information helps to construct objectives for distilling the draft model, aligning its output distributions with those of the target model and thereby improving the token acceptance rate of the draft model. The distillation loss generally takes the form:

$$\ell({\bm{\theta}}) = \frac{1}{n_B}\sum_{{\bm{x}}^{(i)} \in \mathcal{B}} \ell({\bm{x}}^{(i)}, {\bm{\theta}}), \quad \ell({\bm{x}}, {\bm{\theta}}) = D\big(p(\cdot|{\bm{x}}) \,\Vert\, q_{\bm{\theta}}(\cdot|{\bm{x}})\big),$$

where $\mathcal{B} = \{{\bm{x}}^{(i)}\}_{i=1}^{n_B}$ denotes a batch of inputs and $D$ denotes some distance measure.

Distance measure. In the case of auto-regressive models, the prediction distribution is categorical at each token. Often, we can augment the predicted logits with a tunable temperature $\tau$ for the softmax transformation. We then use the popular forward KL and reverse KL (RKL), as well as their mixture (i.e., the JSD), to instantiate $D$:

$$\begin{aligned} &\ell_{KL}({\bm{x}}, {\bm{\theta}}) = D_{\mathrm{KL}}\big(p(\cdot|{\bm{x}}) \,\Vert\, q_{\bm{\theta}}(\cdot|{\bm{x}})\big), \\ &\ell_{RKL}({\bm{x}}, {\bm{\theta}}) = D_{\mathrm{KL}}\big(q_{\bm{\theta}}(\cdot|{\bm{x}}) \,\Vert\, p(\cdot|{\bm{x}})\big), \\ &\ell_{JSD[\beta]}({\bm{x}}, {\bm{\theta}}) = \beta D_{\mathrm{KL}}\big(p(\cdot|{\bm{x}}) \,\Vert\, p^\beta_{\bm{\theta}}(\cdot|{\bm{x}})\big) + (1-\beta) D_{\mathrm{KL}}\big(q_{\bm{\theta}}(\cdot|{\bm{x}}) \,\Vert\, p^\beta_{\bm{\theta}}(\cdot|{\bm{x}})\big), \end{aligned}$$

where $p^\beta_{\bm{\theta}}(\cdot|{\bm{x}}) \triangleq \beta p(\cdot|{\bm{x}}) + (1-\beta) q_{\bm{\theta}}(\cdot|{\bm{x}})$. These objectives differ from the label-based fine-tuning objectives conventionally used in speculative decoding. As shown in Section 5.1, objectives based on the KL divergence prove to be more effective. This is because distributions convey richer information than mere labels, thereby enhancing their capability to guide the student model. Additionally, these objectives improve convergence rates and calibration. The reverse KL is notable for its mode-seeking behavior, which offers unique advantages. In our study, and in alignment with previous research, we empirically determine that the optimal distance measure can vary depending on the task and the relative capacities of the teacher and student models (see §5.1).
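For a single token position, the three distance measures can be written directly over categorical distributions. The following is a minimal pure-Python sketch, not the authors' implementation; it assumes the target and draft distributions come from a temperature-scaled softmax over logits, so every probability is strictly positive and the reverse KL is finite. The toy logits are arbitrary.

```python
import math


def softmax(logits, tau=1.0):
    """Temperature-scaled softmax; a larger tau flattens the distribution."""
    m = max(logits)
    exps = [math.exp((z - m) / tau) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(a, b):
    """D_KL(a || b) for categorical a, b; terms with a_i = 0 contribute nothing."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


def loss_kl(p, q):
    """Forward KL: D_KL(p || q_theta)."""
    return kl(p, q)


def loss_rkl(p, q):
    """Reverse KL: D_KL(q_theta || p)."""
    return kl(q, p)


def loss_jsd(p, q, beta):
    """JSD[beta] = beta*KL(p || m) + (1-beta)*KL(q || m), with m = beta*p + (1-beta)*q."""
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, m) + (1 - beta) * kl(q, m)


p = softmax([2.0, 1.0, 0.1], tau=1.0)   # toy target (teacher) logits
q = softmax([0.5, 1.5, 0.2], tau=1.0)   # toy draft (student) logits
print(loss_kl(p, q), loss_rkl(p, q), loss_jsd(p, q, beta=0.5))
```

Forward and reverse KL generally differ because KL is asymmetric, whereas JSD with $\beta = 0.5$ is symmetric in its two arguments and bounded by $\log 2$.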

Sampling and gradient estimation. Estimating the above objectives involves the expectation over $q_{\bm{\theta}}(\cdot|{\bm{x}})$ or $p(\cdot|{\bm{x}})$, which should be expanded recursively. Once the recursion depth exceeds $1$, we cannot compute $D_{\mathrm{KL}}$ analytically and must rely on Monte Carlo approximation. When sampling from $q_{\bm{\theta}}(\cdot|{\bm{x}})$, we should differentiate through the sampling process for unbiased gradient estimation.

However, this leads to policy-gradient-style estimators, which must rely on elaborate techniques, such as countermeasures against reward hacking and single-step regularization, to reduce gradient variance and stabilize training.

In comparison, a more straightforward approach is to omit the differentiation through the sampling process, where the sample ${\bm{y}}$ is directly plugged into the objective:

$$\ell({\bm{x}}, {\bm{\theta}}) \approx \sum_{j=1}^{|{\bm{y}}|+1} D\big(p(y|{\bm{x}}, {\bm{y}}_{<j}) \,\Vert\, q_{\bm{\theta}}(y|{\bm{x}}, {\bm{y}}_{<j})\big).$$

This way, various distance measures can be readily applied. Moreover, sampling becomes decoupled from the distance measure, i.e., we can sample ${\bm{y}}$ from an arbitrary mixture of $p(\cdot|{\bm{x}})$ and $q_{\bm{\theta}}(\cdot|{\bm{x}})$ but use KL, RKL, or JSD to estimate the distribution misalignment.
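Because the sample ${\bm{y}}$ is treated as a constant, the objective is just a sum of per-position divergences, and each term can be differentiated with respect to the draft model's logits at that position. The sketch below assumes a hypothetical toy draft model whose logits at every position are free parameters. For forward KL, the gradient of $D_{\mathrm{KL}}(p \Vert \mathrm{softmax}(z))$ with respect to the logits $z$ is $\mathrm{softmax}(z) - p$, so a single gradient step already lowers the loss.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def kl(a, b):
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


def sequence_forward_kl(target_dists, draft_logits):
    """sum_j D_KL(p(.|x, y_<j) || q_theta(.|x, y_<j)) over the |y|+1 positions.

    target_dists[j] and draft_logits[j] are computed on the context [x, y_<j];
    the sampled y only fixes these contexts, so no gradient flows through sampling.
    """
    return sum(kl(p, softmax(z)) for p, z in zip(target_dists, draft_logits))


def sgd_step(target_dists, draft_logits, lr=1.0):
    """One step on the toy logits: d KL(p || softmax(z)) / dz = softmax(z) - p."""
    return [[zi - lr * (qi - pi) for zi, qi, pi in zip(z, softmax(z), p)]
            for p, z in zip(target_dists, draft_logits)]


target = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]   # p at positions j = 1..3
logits = [[0.0, 0.0, 0.0] for _ in target]                      # uniform toy draft
before = sequence_forward_kl(target, logits)
after = sequence_forward_kl(target, sgd_step(target, logits))
print(f"{before:.4f} -> {after:.4f}")
```

In a real system, the same stop-gradient effect is obtained by generating ${\bm{y}}$ without tracking gradients and then running a single teacher-forced forward pass of the draft model over $[{\bm{x}}, {\bm{y}}]$.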

Intuitively, samples from the teacher model are usually coherent but may be difficult for the small student model to fit, while samples from the student model may be less structured or even meaningless. One workaround is to trade off between them via mixed sampling, i.e., $y_j \sim \beta p(\cdot|{\bm{x}}, {\bm{y}}_{<j}) + (1-\beta) q_{\bm{\theta}}(\cdot|{\bm{x}}, {\bm{y}}_{<j})$.
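Token-level mixed sampling only changes how each $y_j$ is drawn; the loss is still computed position by position as above. The sketch below assumes hypothetical callables `teacher_next(context)` and `student_next(context)` that return next-token distributions as lists. Setting $\beta = 1$ recovers teacher sampling, and $\beta = 0$ recovers student sampling.

```python
import random


def mixed_rollout(teacher_next, student_next, prompt, beta, max_len, eos, rng=None):
    """Sample y token by token from beta * p(.|x, y_<j) + (1 - beta) * q(.|x, y_<j)."""
    rng = rng or random.Random()
    y = []
    for _ in range(max_len):
        context = prompt + y
        p, q = teacher_next(context), student_next(context)
        mix = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
        token = rng.choices(range(len(mix)), weights=mix)[0]
        y.append(token)
        if token == eos:
            break
    return y


# Toy models over a 3-token vocabulary (token 2 is EOS):
# the teacher always emits token 1, the student always emits token 0.
def teacher_next(context):
    return [0.0, 1.0, 0.0]


def student_next(context):
    return [1.0, 0.0, 0.0]


print(mixed_rollout(teacher_next, student_next, [], beta=0.5, max_len=6, eos=2,
                    rng=random.Random(0)))
```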

Online Knowledge Distillation

This section extends knowledge distillation for speculative decoding to online environments. The approach improves the draft model using the results of speculative decoding, thus dynamically adapting to the query distribution and improving the token acceptance rate. We also discuss the trade-offs of our approach when integrating it into LLM serving systems.

Algorithm

Algorithm 1: Online speculative decoding (OSD). Input: target LLM $p(\cdot|{\bm{x}})$, draft LLM $q_{\bm{\theta}}(\cdot|{\bm{x}})$, warmup dataset $\mathcal{D}$, online data stream $\mathcal{S}$, guess number $k$, temporary buffer $\mathcal{R}$, replay buffer $\mathcal{Q}$, update interval for the draft model $I$.

We depict our online speculative decoding algorithm (OSD) in Algorithm 1. OSD begins by training the draft model on the warmup dataset (Line 2). The serving system then continuously handles incoming requests (Lines 6 to 23). For each request, it uses standard speculative decoding (Lines 10-11) to generate responses until the ⟨EOS⟩ token. Concurrently, OSD tracks the token index (`error_index`) and the target logits where the draft model proposes wrong tokens (Line 15). Leveraging the tracked information, OSD updates the draft model every $I$ iterations, with $I$ being a dynamically adjustable parameter. OSD updates the draft model with different loss functions (Line 20), as described in Section 4.1. The choice of loss function depends on the specific (draft, target) model pair and the corresponding input data.
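The loop below is a toy, end-to-end sketch of this procedure, not the released implementation. It assumes a hypothetical tabular draft model (next-token logits indexed by the previous token), a fixed toy target distribution, and forward KL as the distillation loss, and it skips the warmup phase. At each rejected position it records the context and the target distribution (the `error_index` information), updates the draft every $I$ requests, and then releases the buffer.

```python
import math
import random

V = 4  # toy vocabulary size (hypothetical)


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def target_dist(prev):
    """Toy target p(. | prev): strongly prefers token (prev + 1) % V."""
    return [0.85 if t == (prev + 1) % V else 0.05 for t in range(V)]


class TabularDraft:
    """Hypothetical draft q_theta: one row of next-token logits per previous token."""

    def __init__(self):
        self.logits = {t: [0.0] * V for t in range(V)}

    def dist(self, prev):
        return softmax(self.logits[prev])

    def distill(self, batch, lr=0.5):
        # Forward-KL SGD on stored (context, target distribution) pairs; grad = q - p.
        for prev, p in batch:
            q = self.dist(prev)
            self.logits[prev] = [z - lr * (qi - pi)
                                 for z, qi, pi in zip(self.logits[prev], q, p)]


def serve(num_requests, k=5, interval=8, rounds_per_request=4, seed=0):
    rng = random.Random(seed)
    draft, replay, rates = TabularDraft(), [], []
    for step in range(1, num_requests + 1):
        ctx, accepted, proposed = rng.randrange(V), 0, 0
        for _ in range(rounds_per_request):           # speculative-decoding rounds
            for _ in range(k):
                q, p = draft.dist(ctx), target_dist(ctx)
                y = rng.choices(range(V), weights=q)[0]
                proposed += 1
                if rng.random() <= p[y] / q[y]:
                    accepted += 1
                    ctx = y
                else:
                    replay.append((ctx, p))           # error position: context + target dist
                    residual = [max(0.0, pt - qt) for pt, qt in zip(p, q)]
                    ctx = rng.choices(range(V), weights=residual)[0]
                    break
            else:                                     # a = k: bonus token from the target
                ctx = rng.choices(range(V), weights=target_dist(ctx))[0]
        rates.append(accepted / proposed)
        if step % interval == 0 and replay:           # update every I requests ...
            draft.distill(replay)
            replay.clear()                            # ... then release the buffer
    return rates


rates = serve(200)
print(f"first 8 requests: {sum(rates[:8]) / 8:.2f}, last 20: {sum(rates[-20:]) / 20:.2f}")
```

The uniform initial draft accepts about 40% of proposals on this toy target; after a few updates the draft's distribution approaches the target's and the acceptance rate climbs toward 1.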

Discussion. OSD uses a replay buffer, $\mathcal{Q}$, to capture all pertinent information for updating the draft model. Various eviction policies can be employed to keep $\mathcal{Q}$ compact; for example, one could retain only the most informative pairs or the most recent entries. Similarly, users may retain data in $\mathcal{Q}$ after using it, so that it contributes to multiple model updates. Determining the optimal eviction/retention strategy is left for future exploration. In the current study, we do not evict any pairs, and we release $\mathcal{Q}$ after each model update. Furthermore, $I$ is a dynamic parameter: depending on the system load and the rate at which the query distribution changes, users can adjust $I$ accordingly. For example, we can perform a gradient update opportunistically, only when service traffic is not spiking (i.e., spare FLOPs are available). Overall, OSD continuously improves the draft model’s approximation (indicated by an increased token acceptance rate $\alpha$) by learning from the target model during the serving phase. We next demonstrate how the enhanced acceptance rate directly contributes to a reduction in request latency.
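The retention choices discussed above can be captured by a small buffer abstraction. This is a hypothetical sketch: `capacity=None` mirrors the setting used in this paper (keep every pair, release $\mathcal{Q}$ after each update), while a finite `capacity` keeps only the most recent pairs and retains them across updates, one of the alternative policies mentioned above. The `spare_flops` flag in `should_update` is likewise an assumed signal from the serving system.

```python
from collections import deque


class ReplayBuffer:
    """Replay buffer Q of (context, target distribution) pairs."""

    def __init__(self, capacity=None):
        self.capacity = capacity
        self.items = deque(maxlen=capacity)      # maxlen=None means unbounded

    def add(self, context, target_dist):
        self.items.append((context, target_dist))  # oldest pair dropped when full

    def take_batch(self):
        batch = list(self.items)
        if self.capacity is None:
            self.items.clear()                   # release Q after the update
        return batch


def should_update(step, interval, spare_flops):
    """Update every `interval` iterations, but only when spare FLOPs are available."""
    return step % interval == 0 and spare_flops
```

This `should_update` simply skips an update during a traffic spike; a production scheduler might instead defer it until load drops.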

Latency & Flops Analysis

Latency. As detailed in Appendix 7.2, compared with standard speculative decoding, the expected speedup of online speculative decoding is $\frac{1+\alpha_2+\alpha_2^2+\dots+\alpha_2^{k}}{1+\alpha_1+\alpha_1^2+\dots+\alpha_1^k}$, where $\alpha_1$ and $\alpha_2$ denote the token acceptance rates of the draft model before and after the OSD updates, respectively. Based on the data from our experiments (Table 1), compared to standard speculative decoding, we expect speedups for Vicuna-7B (with LLaMA-160M as the draft model) of $2.42\times$, $1.43\times$, $1.64\times$, and $1.22\times$ on Spider, Gsm8k, Code-search-Python, and Alpaca-finance, respectively. Similarly, for Flan-T5-XL 3B (with T5-small 80M as the draft model), the speedups are $3.06\times$, $1.76\times$, $2.72\times$, and $1.55\times$ across the same four datasets.
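The ratio above is straightforward to compute. Plugging in the Original and TF columns of Table 1 with $k = 5$ proposed tokens reproduces the Spider factors quoted above; the function name is illustrative.

```python
def osd_speedup(alpha_before: float, alpha_after: float, k: int) -> float:
    """Ratio of expected tokens per target run: sum_{i=0}^k a2^i / sum_{i=0}^k a1^i."""
    num = sum(alpha_after ** i for i in range(k + 1))
    den = sum(alpha_before ** i for i in range(k + 1))
    return num / den


# Spider, Table 1: Original vs. TF acceptance rates, k = 5 proposed tokens.
print(round(osd_speedup(0.28, 0.76, k=5), 2))   # 2.42 (Vicuna-7B / LLaMA-160M)
print(round(osd_speedup(0.13, 0.78, k=5), 2))   # 3.06 (FLAN-T5-XL / T5-small)
```

The draft-to-target cost ratio $c$ does not appear because it is the same in numerator and denominator and cancels.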

FLOPs. (1) The FLOPs required to update the draft model are significantly fewer than those needed for inference on the large model. As elaborated in Appendix 7.3, for the two evaluated model pairs, the FLOPs ratio between the target model and the draft model is 18.75 for the pair (LLaMA-160M, Vicuna-7B) and 12.6 for the pair (T5-small 80M, Flan-T5-XL 3B). (2) In practical systems, the FLOPs required for inference are significantly below the machine’s capacity. Appendix 7.3 provides an analysis of Chatbot Arena traces in which the cluster’s computational utilization is under 1 percent. Given these two observations, the FLOPs spent on running and updating the draft model are insignificant compared with the FLOPs consumed by the target model and the cluster’s total FLOPs.

Experiments

To assess the efficacy of our method, we first evaluate its ability to improve the token acceptance rate ($\alpha$) in an offline setting. This provides a theoretical upper bound on the performance improvements achievable when the query distribution remains constant. Subsequently, we examine the approach’s impact in an online environment, finding that the acceptance rate improves even with a moderate amount of data while reaching levels comparable to those in the offline scenario. Throughout our experiments, we employ two target models ($M_p$): Vicuna-7B and FLAN-T5-XL (3B). For Vicuna-7B, we use LLaMA-160M as the draft model ($M_q$). For FLAN-T5-XL, we use T5-Small as the draft model. We evaluate performance across four diverse datasets: Text-to-SQL (Spider), grade school math (Gsm8k), Python code generation (Code-search-Python), and financial question answering (Alpaca-finance). In all experiments, we set the number of proposed tokens to 5 for speculative decoding. For all online experiments, we fix the update interval $I$ at 8.

Offline Evaluation

In this section, we assess the efficacy of using knowledge distillation to train a small model specifically for speculation in an offline environment. In this setting, the draft model $M_q$ has unrestricted access to the dataset, and the query distribution remains stable. To emulate these offline conditions, we distill $M_q$ on the training dataset for two epochs and then evaluate its performance by measuring the average token acceptance rate ($\alpha$) on the test set. As detailed in Section 4.1, we evaluate various sampling methods, namely teacher sampling, student sampling, and mixed token-level sampling. Table 1 displays the token acceptance rate of the draft model for each method on the test dataset, using forward KL as the distance metric. For comparison, we also provide the acceptance rates for fine-tuning on teacher-generated labels and for the original model.

For both Vicuna-7B and FLAN-T5-XL, teacher sampling outperforms the other sampling methods, achieving the highest acceptance rate. Furthermore, knowledge distillation proves effective in improving the draft model’s approximation, resulting in a high token acceptance rate. Intriguingly, we also find that fine-tuning on teacher-generated labels yields impressive performance on the Vicuna-7B model. Lastly, we experimented with other distance measures, namely reverse KL and JSD. However, these measures either matched or underperformed forward KL. These empirical findings underscore that the optimal distance measure or sampling method varies with the task and model, and we leave finding the best combination to future work.

| Model | Task | Original | FT | TF | SF | MixF |
|---|---|---|---|---|---|---|
| Vicuna-7B | Spider | 0.28 | 0.74 | 0.76 | 0.62 | 0.70 |
| Vicuna-7B | Gsm8k | 0.58 | 0.74 | 0.75 | 0.67 | 0.73 |
| Vicuna-7B | Code-search-Python | 0.38 | 0.65 | 0.65 | 0.51 | 0.61 |
| Vicuna-7B | Alpaca-finance | 0.57 | 0.68 | 0.67 | 0.63 | 0.65 |
| FLAN-T5-XL | Spider | 0.13 | 0.33 | 0.78 | 0.67 | 0.70 |
| FLAN-T5-XL | Gsm8k | 0.29 | 0.50 | 0.62 | 0.51 | 0.55 |
| FLAN-T5-XL | Code-search-Python | 0.28 | 0.44 | 0.81 | 0.67 | 0.78 |
| FLAN-T5-XL | Alpaca-finance | 0.39 | 0.56 | 0.63 | 0.59 | 0.60 |

Table 1: Token acceptance rates ($\alpha$) after two epochs. FT: Finetuning on teacher-generated labels. TF, SF, MixF: Teacher, student, and mixed token-level sampling, respectively, all with forward KL.

Online Evaluation

Online Learning. First, we evaluate the effectiveness of our online algorithm by addressing two key questions: (1) Does the online algorithm increase the token acceptance rate? And is this enhancement comparable to the rates achieved in offline settings, which serve as an upper bound given their full access to data? (2) How quickly does the online algorithm increase the token acceptance rate, thereby indicating that the compact model has grasped the underlying distribution?

We simulate the online serving process by iterating through the datasets, extracting prompts, and streaming generation requests. The system uses speculative decoding for each request. Throughout this serving phase, we continually refine the draft models, as detailed in Algorithm 1. For our baseline, we envision a scenario where the serving system can collect data offline to distill an initial draft model, which is then deployed online to serve future requests. We simulate this process by using 10% of the dataset to distill the draft model, which remains static during online serving. For evaluation, we report token acceptance rates averaged over the most recent 50 requests, which reflects $M_q$’s efficacy on the most current data.
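The evaluation metric is a simple sliding-window average. The sketch assumes that per-request acceptance rates (accepted over proposed tokens for each request) are averaged with equal weight over the window; pooling token counts across the window would be an equally plausible reading of the text. The class name is hypothetical.

```python
from collections import deque


class RecentAcceptanceRate:
    """Token acceptance rate averaged over the most recent `window` requests."""

    def __init__(self, window=50):
        self.rates = deque(maxlen=window)   # older requests fall out automatically

    def record(self, accepted: int, proposed: int) -> float:
        """Add one request's counts and return the current windowed average."""
        self.rates.append(accepted / proposed)
        return sum(self.rates) / len(self.rates)
```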

Online acceptance rate (α) across different datasets. The x-axis represents the number of records that OSD has processed. Alpha is averaged over the most recent 50 records. Distribution Shift: Alpha is averaged over the most recent 100 records.

As depicted in Figure 3, for both Vicuna-7B and FLAN-T5, OSD initially yields a lower token acceptance rate than the offline distilled model. Nevertheless, these acceptance rates rise swiftly as the draft model is exposed to more data. We also annotate the token acceptance rate from the offline setting to highlight the potential peak performance that the online serving system could reach. In all instances, the online setting achieves comparable results. In some scenarios, OSD even surpasses the offline test alphas. This discrepancy can be attributed to the fact that offline test alphas are assessed on the entire test dataset, whereas the online alphas represent the moving average over the latest 50 requests. It is plausible that OSD performs optimally on specific data subsets, particularly if those subsets are more narrowly distributed than the complete dataset.

Distribution Shifts. We evaluate OSD’s ability to adapt to changes in the data distribution; the dataset preparation is detailed in the appendix. As illustrated in Figure 4, OSD’s alpha value dips notably at distribution boundaries, especially around 2K, 4K, and 6K records. This is expected, since the draft model initially struggles when faced with a new distribution. However, the alpha value rebounds quickly as OSD processes more data, highlighting its adaptability to shifting query distributions.

We also compare our results with those from a static setting. To ensure that the draft model is not merely memorizing data, we chose samples distinct from the online evaluation data. These samples correspond to 30%, 50%, 70%, and 100% of each dataset’s online evaluation volume, i.e., 0.6K, 1K, 1.4K, and 2K samples, respectively. As depicted in Figure 4, upon an initial shift in query distribution, OSD’s performance matches or slightly trails distillation with 30% of the data. However, it quickly catches up, matching or even surpassing the performance obtained with 70% to 100% data access. This highlights OSD’s ability to rival models fully exposed to the query distribution, even without prior knowledge of the underlying query dynamics.

Real Workloads. We evaluate OSD on real LMSYS-chat conversations (Appendix 7.6) that span 4 months. First, we categorize conversations by language and focus on conversations in the top five languages, excluding English. For each chosen language, we use an independent LLaMA-160M as the draft model. All draft models share the same Vicuna-7B target model. The token acceptance rate, averaged over the latest 100 requests and shown in Figure 5, reveals that OSD improves the rates by 0.1 to 0.2, even with fewer than 2K data points. Notably, Japanese was the easiest language, while Portuguese was the hardest. We also clustered English conversations by topic using a fine-tuned distilled BERT model, focusing on the top five topics. For topics with over 5K conversations, we sampled evenly to keep each within 5K. Figure 5 shows acceptance rates above 0.6 across topics, with Social and Computer discussions peaking near 0.8.

Figure 5: Chatbot Arena conversations clustered by language and topic.

Figure 6: Precision and recall of high-frequency tokens. The x-axis ranks tokens by how often they occur in the generated answers; for instance, token 1 is the most frequent. Precision = (# of times token $i$ is accepted by the target model) / (# of times token $i$ is proposed by the draft model). Recall = (# of times token $i$ is accepted by the target model) / (# of times token $i$ appears in the final answer).

Qualitative Analysis

In this section, we analyze how our method improves the token acceptance rate and which tokens the draft model learns under different query distributions.

High-frequency token precision and recall. In our experiment on the Spider dataset, Vicuna-7B is the target model and LLaMA-160M is the draft model. We identify the 100 tokens most frequently generated by the target model. Together they account for 72.2% of all generated tokens, and token frequencies follow a power-law distribution. Figure 6 shows a marked improvement in both the precision and the recall of these tokens after distillation, measured on the test dataset in an offline evaluation.
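The precision and recall defined for Figure 6 reduce to three token counts. The minimal sketch below assumes that the proposed, accepted, and final tokens have already been logged as flat lists of token strings. The function name and this logging format are hypothetical. Restricting the analysis to the 100 most frequent answer tokens corresponds to `top_n=100`.

```python
from collections import Counter
from typing import Dict, Iterable, Optional, Tuple


def token_precision_recall(
    proposed: Iterable[str],
    accepted: Iterable[str],
    final: Iterable[str],
    top_n: Optional[int] = None,
) -> Dict[str, Tuple[float, float]]:
    """Per-token (precision, recall).

    proposed: every token the draft model proposed
    accepted: every proposed token the target model accepted
    final:    every token in the final answers
    top_n:    keep only the top_n most frequent tokens in `final` (None = all)
    """
    n_proposed = Counter(proposed)
    n_accepted = Counter(accepted)
    n_final = Counter(final)
    result = {}
    for tok, count in n_final.most_common(top_n):
        # Assumption: a token the draft never proposed gets precision 0.
        precision = n_accepted[tok] / n_proposed[tok] if n_proposed[tok] else 0.0
        recall = n_accepted[tok] / count
        result[tok] = (precision, recall)
    return result
```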

| Dataset | Tokens with the greatest precision increase | Tokens with the greatest recall increase |
|---|---|---|
| Spider | AV, SELECT, first, ⟨EOS⟩, template, SUM, G, COUNT, \n, city, WHERE, ’;, (, IST, id | SELECT, \*, FROM, (, IST, \*), \n, COUNT, G, first, WHERE, ⟨EOS⟩, IN, ;, MAX, ’; |
| Gsm8k | ⟨EOS⟩, >>, +, To, <<, this, =, %, know, are, We, calculate, be, The, have | start, >>, <<, +, find, how, we, =, fore, To, so, \ ⟨EOS⟩, then, let |
| Alpaca-Finance | 1, Here, (, :, provide, depends, However, goals, amount, 3, there, The, \n, personal, will | general, 1, several, This, depends, Here, provide, However, goals, over, (, If, amount, it, can |
| Code-Python | ”’, (, Here, python, ’, how, doc, snippet, import, based, {, Python, This, :, you | Here, This, snippet, ”’, ’, how, python, (, takes, Python, you, doc, an, import, def |

Table 2: Top 15 tokens with the largest recall and precision improvements on each dataset. The _ prefix, which the LLaMA tokenizer uses to mark a preceding space, is omitted.

Tokens learned across different datasets. To understand the draft model’s learning trends, we analyze, for each dataset, the top 15 tokens with the largest precision and recall improvements among the 100 most frequent tokens. As detailed in Table 2, the improved tokens align well with the underlying data distribution. For example, on the Spider dataset, which mostly involves generating SQL statements, tokens such as SELECT and WHERE have notably higher acceptance rates after distillation. Similarly, on the grade-school math dataset (Gsm8k), tokens such as <<, >>, =, and + stand out. These patterns show that the draft model adapts to predict tokens consistent with the data distribution.
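The selection behind Table 2 can be sketched as a ranking of per-token gains. This assumes that per-token precision (or recall) has been measured before and after distillation and stored as a dictionary from token to value. The function name is hypothetical.

```python
from typing import Dict, Iterable, List


def top_improved_tokens(
    before: Dict[str, float],
    after: Dict[str, float],
    candidates: Iterable[str],
    n: int = 15,
) -> List[str]:
    """Return the n candidate tokens whose metric (precision or recall) rose the most.

    `candidates` would be the 100 most frequent tokens. A token missing from
    `before` or `after` is treated as having a metric of 0 (assumption).
    """
    gains = {tok: after.get(tok, 0.0) - before.get(tok, 0.0) for tok in candidates}
    # Largest gain first; Python's sort is stable, so ties keep candidate order.
    return sorted(gains, key=gains.get, reverse=True)[:n]
```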

Conclusion

The efficiency of speculative decoding hinges on how well the draft model approximates the target model. We introduce online speculative decoding, which continuously improves the draft model as the query distribution changes. Experiments on both synthetic and real data demonstrate that online speculative decoding quickly adapts to new data distributions and significantly increases the token acceptance rate.

Appendix

Speedup of Speculative Decoding

As proved in prior work, the expected speedup of (offline) speculative decoding over standard decoding is $\frac{1-\alpha^{k+1}}{(1-\alpha)(ck+1)}$. Let $T$ be the time taken by a single run of $M_p$, and define the cost coefficient $c$ as the ratio of the time taken by a single run of $M_q$ to that of $M_p$. Each execution of lines 7 to 8 takes $Tck + T$ ($k$ sequential draft runs plus one parallel target run) and, on average, yields $\frac{1-\alpha^{k+1}}{1-\alpha}$ tokens. The average time to produce one token with speculative decoding is therefore $\frac{(ck+1)(1-\alpha)}{1-\alpha^{k+1}}T$. In contrast, standard decoding takes $T$ per token. Dividing the two gives the wall-clock speedup of offline speculative decoding, $\frac{1-\alpha^{k+1}}{(1-\alpha)(ck+1)}$.
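The derivation can be checked numerically. The minimal sketch below normalizes $T$ to 1, and the function names are illustrative rather than taken from the paper.

```python
def expected_tokens_per_run(alpha: float, k: int) -> float:
    """Expected tokens per target-model run: (1 - alpha^(k+1)) / (1 - alpha)."""
    if alpha == 1.0:
        return float(k + 1)  # limit of the geometric series: every proposal accepted
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def time_per_token(alpha: float, k: int, c: float, T: float = 1.0) -> float:
    """Average wall-clock time per token; one draft-then-verify iteration costs T*c*k + T."""
    return (T * c * k + T) / expected_tokens_per_run(alpha, k)


def expected_speedup(alpha: float, k: int, c: float) -> float:
    """Speedup over standard decoding, which spends T per token (T = 1 here)."""
    return 1.0 / time_per_token(alpha, k, c, T=1.0)
```

With $\alpha = 0$, every iteration still pays for $k$ draft runs but yields a single token, so `expected_speedup(0.0, k, c)` falls below 1 for any $c > 0$.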

Latency Analysis

Suppose OSD improves the token acceptance rate from $\alpha_1$ to $\alpha_2$, and let $T$ be the time of a single target-model run, i.e., the per-token time of standard decoding. Based on Equation [eq:gen_len], this improvement decreases the average generation time per token from $\frac{(ck+1)(1-\alpha_1)}{1-\alpha_{1}^{k+1}}T$ to $\frac{(ck+1)(1-\alpha_2)}{1-\alpha_{2}^{k+1}}T$. Because $c$, $k$, and $T$ are unchanged, they cancel, giving a speedup factor of $$\frac{1-\alpha_2^{k+1}}{1-\alpha_1^{k+1}}\cdot\frac{1-\alpha_1}{1-\alpha_2} = \frac{1+\alpha_2+\alpha_2^2+\dots+\alpha_2^{k}}{1+\alpha_1+\alpha_1^2+\dots+\alpha_1^k}$$ compared with standard speculative decoding.
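Because only the two geometric sums remain, the speedup from a higher acceptance rate is easy to evaluate. The small sketch below (the helper name is hypothetical) uses the sum form, which stays well defined even when an acceptance rate equals 1. A test checks it against the closed form.

```python
def osd_speedup(alpha1: float, alpha2: float, k: int) -> float:
    """Latency speedup from raising the acceptance rate alpha1 -> alpha2 at fixed k and c.

    Computes (1 + a2 + ... + a2^k) / (1 + a1 + ... + a1^k).
    """
    numerator = sum(alpha2 ** i for i in range(k + 1))
    denominator = sum(alpha1 ** i for i in range(k + 1))
    return numerator / denominator
```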

In the above analysis, we omitted the additional latency of updating the draft model, for the following reasons. (1) As shown below, the additional computational cost (FLOPs) of the update is marginal compared with the cost of running the target model. (2) Updates are periodic, and during periods of moderate request load they leave the latency of individual requests largely unaffected. Moreover, because updating the small draft model is considerably less resource-intensive than inference with the target model, the associated latency may be masked almost entirely. (3) Updating and inference can even be executed concurrently on separate devices.

FLOPs Analysis

The FLOPs required to update the draft model are significantly fewer than those needed for inference with the target model. Let $L$ denote the average length of a generated sequence, and let the draft model propose $k$ tokens per verification. The expected number of tokens produced by a single run of the target LLM, denoted $a$, is given by Equation [eq:gen_len]. OSD therefore performs $\frac{L}{a}$ verifications per request, each verifying $k+1$ tokens. Let $F_{qfwd}$ denote the FLOPs of a single forward pass of the draft model per token, and $F_{pfwd}$ the FLOPs of a single forward pass of the target model per token. The computational demand (in FLOPs) of the draft and target models for one request is then
$$\text{FLOPs}(\text{draft}) = \frac{L}{a} \times k \times F_{qfwd}, \qquad \text{FLOPs}(\text{target}) = \frac{L}{a} \times (k+1) \times F_{pfwd}.$$
Let $F_{qbwd}$ denote the FLOPs required to update the draft model per token. The cumulative FLOPs necessary to process $I$ requests is
$$\frac{LI}{a} \times \left[k \times F_{qfwd} + (k+1) \times F_{pfwd}\right] + I \times L \times F_{qbwd}.$$
Based on prior findings, training is approximately three times as costly as inference: roughly 6 FLOPs per parameter to train on a single token versus 2 FLOPs per parameter to run inference on one token, so $F_{qbwd} \approx 3F_{qfwd}$. Rewriting $I \times L \times F_{qbwd}$ as $\frac{LI}{a} \times 3a \times F_{qfwd}$ simplifies the total to
$$\frac{LI}{a}\left[(k + 3a) \times F_{qfwd} + (k+1) \times F_{pfwd}\right].$$
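The per-request terms and the simplification can be sanity-checked numerically. The minimal sketch below assumes the 2 and 6 FLOPs-per-parameter approximations above and treats $L$, $k$, and $a$ as given averages. The function and argument names are illustrative.

```python
def request_flops(L: float, k: int, a: float, n_draft: float, n_target: float):
    """Per-request FLOPs split into draft inference, target verification and draft updates.

    Assumes 2 FLOPs per parameter per token for a forward pass and
    6 FLOPs per parameter per token for a training step (about 3x inference).
    """
    f_qfwd = 2 * n_draft   # draft forward pass, per token
    f_pfwd = 2 * n_target  # target forward pass, per token
    f_qbwd = 6 * n_draft   # draft update, per token
    verifications = L / a  # each target run yields a tokens on average
    draft = verifications * k * f_qfwd
    target = verifications * (k + 1) * f_pfwd
    update = L * f_qbwd
    return draft, target, update
```

Summing `draft + update` reproduces the $(k + 3a) \times F_{qfwd}$ term of the simplified expression, scaled by $\frac{L}{a}$.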

The ratio of the FLOPs needed to run the target model to those of the draft model (including its updates) is $$\frac{(k+1)\times F_{pfwd}}{(k+3a)\times F_{qfwd}}.$$ For the two model pairs evaluated, we assume an average of $k = 5$ proposed tokens per run. (1) For (LLaMA-160M, Vicuna-7B), the average acceptance rate is 0.71, so Equation [eq:gen_len] gives $a \approx 3$ and the ratio is approximately $\frac{(5+1) \times 7B}{(5+3 \times 3) \times 160M} = 18.75$. (2) For (T5-small 80M, Flan-T5-XL 3B), the average acceptance rate is 0.76, so $a \approx 3.36$ and the ratio is roughly $\frac{(5+1) \times 3B}{(5+3 \times 3.36) \times 80M} \approx 14.9$.
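Deriving $a$ from the acceptance rate with Equation [eq:gen_len], instead of rounding it by hand, gives a quick check of both ratios. The sketch below uses $k = 5$. The helper name is hypothetical, and the parameter counts are the nominal model sizes.

```python
def target_to_draft_ratio(alpha: float, k: int, n_draft: float, n_target: float) -> float:
    """(k+1) * F_pfwd / ((k + 3a) * F_qfwd); the 2 FLOPs/parameter factor cancels."""
    a = (1 - alpha ** (k + 1)) / (1 - alpha)  # expected tokens per target run (alpha < 1)
    return (k + 1) * n_target / ((k + 3 * a) * n_draft)


for name, alpha, n_draft, n_target in [
    ("LLaMA-160M / Vicuna-7B", 0.71, 160e6, 7e9),
    ("T5-small / Flan-T5-XL", 0.76, 80e6, 3e9),
]:
    print(f"{name}: ratio ~ {target_to_draft_ratio(alpha, 5, n_draft, n_target):.1f}")
```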

In practical systems, the FLOPs required for inference are far below the hardware’s capacity. Consider LMSYS-Chat-1M: it comprises traces spanning 125 days with 1,000,000 requests, averaging fewer than 2,000 tokens per request (prompts and responses combined). Suppose a 30B model is served on 8 A100 GPUs, and again assume 2 FLOPs per parameter per token. The FLOPs consumed per second can then be estimated as
$$\frac{2000 \times 1{,}000{,}000}{125 \times 24 \times 3600} \times 30 \times 10^9 \times 2 \approx 1.1 \times 10^{13} \text{ FLOP/s (about 11 TFLOP/s)}.$$
In contrast, 8 A100 GPUs offer a combined peak of $8 \times 312$ TFLOP/s, so computational utilization is below 0.5%. Arena (the platform that generated LMSYS-Chat-1M) may not be the most efficient serving system and may not carry heavy traffic, but it is the only publicly accessible LLM service trace. Based on the above calculation, computational utilization would remain low even if the load were amplified several times.
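The estimate is simple arithmetic over the trace statistics, so it can be reproduced in a few lines. The minimal sketch below assumes a perfectly uniform load over the trace period and 2 FLOPs per parameter per token. The helper name is illustrative, and 312 TFLOP/s is the A100's dense FP16/BF16 tensor-core peak.

```python
def serving_utilization(tokens_per_request, n_requests, days, n_params, n_gpus, peak_flops_per_gpu):
    """Average FLOP/s demanded by a request trace vs. the cluster's peak FLOP/s.

    Assumes 2 FLOPs per parameter per token and a perfectly uniform load.
    """
    seconds = days * 24 * 3600
    tokens_per_second = tokens_per_request * n_requests / seconds
    demand = tokens_per_second * n_params * 2
    capacity = n_gpus * peak_flops_per_gpu
    return demand, demand / capacity


# LMSYS-Chat-1M figures quoted above; 30B-parameter model on 8 A100s.
demand, utilization = serving_utilization(2000, 1_000_000, 125, 30e9, 8, 312e12)
print(f"demand ~ {demand:.2e} FLOP/s, utilization ~ {utilization:.2%}")
```

Real traffic is bursty, so peak-hour utilization will be higher than this average. However, the gap to the hardware peak is large enough that the qualitative conclusion is unlikely to change.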

Data Mix

A natural question is whether the draft model, once adapted to a new distribution, loses its prior knowledge. To probe this, we conducted an experiment mixing 2K prompts each from the Gsm8k and Alpaca-finance datasets. During online serving, for the first 2K requests we update the draft model only on data from Gsm8k; for the remaining 2K requests we update it only on data from Alpaca-finance. We report the average token acceptance rate over all requests, segmented by data source (Gsm8k versus Alpaca-finance). As shown in Figure 7, in the first half the acceptance rate for Gsm8k increases as the draft model is exposed to more data, while the acceptance rate ($\alpha$) for Alpaca-finance remains flat. This is expected, since the draft model is updated only on Gsm8k data. In the second half, the acceptance rate for Alpaca-finance also rises, while the rate for Gsm8k remains stable, suggesting that the draft model retains what it has learned without signs of forgetting.
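This experiment separates serving from learning. Every request is served and counted toward its source’s acceptance rate, but only one source feeds the draft-model update at a time. The sketch below shows that gate. The function name, the source labels, and the zero-based request index are assumptions.

```python
def admit_to_buffer(request_index: int, source: str, switch_at: int = 2000) -> bool:
    """Hypothetical update gate for the data-mix experiment.

    All requests are served and measured, but only samples from the currently
    allowed source are used to update the draft model: Gsm8k for the first
    `switch_at` requests, Alpaca-finance afterwards.
    """
    allowed = "gsm8k" if request_index < switch_at else "alpaca-finance"
    return source == allowed
```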


Figure 7: Mix of distributions.

Data Preparation for Distribution Shift Analysis

To emulate a shift in distribution, we select 2K prompts from each dataset under evaluation. The data from the four datasets are directly concatenated, so that records $i\times2\text{K}$ to $(i+1)\times2\text{K}$ belong solely to dataset $i$.
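The concatenation can be written as a short sketch. It assumes each dataset is a list of prompt strings. The function names are illustrative, and the seeded random sampling stands in for whatever selection procedure was actually used, which the text does not specify.

```python
import random
from typing import List, Sequence


def build_shift_stream(datasets: Sequence[Sequence[str]], per_dataset: int = 2000, seed: int = 0) -> List[str]:
    """Concatenate `per_dataset` prompts from each dataset in order, so that
    records [i * per_dataset, (i + 1) * per_dataset) all come from dataset i.

    Random sampling with a fixed seed is an assumption.
    """
    rng = random.Random(seed)
    stream: List[str] = []
    for prompts in datasets:
        if len(prompts) < per_dataset:
            raise ValueError("dataset has fewer prompts than per_dataset")
        stream.extend(rng.sample(list(prompts), per_dataset))
    return stream


def source_of(record_index: int, per_dataset: int = 2000) -> int:
    """Index of the dataset that a record in the stream came from."""
    return record_index // per_dataset
```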

Arena Dataset

To speed up experimental evaluation, we randomly sample a subset of 10K records from LMSYS-Chat-1M, a large real-world LLM conversation dataset. This dataset covers interactions with 25 models from April to August 2023 and features conversations in over 150 languages. For all experiments, we keep only conversations with Vicuna models.

Footnotes

  1. ${\bm{y}}_{<i}$ refers to $\{y_j\}_{j=1}^{i-1}$.

│ └─0 text "Knowledge distillation is a general framework to align the predictive\ndistribution of a small model (i.e., student model) with that of a\nlarger one (i.e., teacher model). Prior research has utilized knowledge\ndistillation to compress neural networks, resulting in decreased\ninference costs and memory requirements. We posit that knowledge\ndistillation is highly effective for speculative decoding. In this\napproach, the draft model acts as the student and the target model\nserves as the teacher. During speculative decoding, we have complete\ninformation on both the proposed and verified probabilities of each\ntoken. This information helps to construct objectives for distilling the\ndraft model, aligning its output distributions with those of the target\nmodel and thereby improving the token acceptance rate of the draft\nmodel. The distillation loss generally takes the form\n$$\\label{eq:distill}\n\\small\n\\ell({\\bm{\\theta}}) = \\frac{1}{n_B}\\sum_{{\\bm{x}}^{(i)} \\in \\mathcal{B}} \\ell({\\bm{x}}^{(i)}, {\\bm{\\theta}}), \\quad \\ell({\\bm{x}}, {\\bm{\\theta}}) = D ({p(\\cdot|{\\bm{x}})} \\Vert {q_{\\bm{\\theta}}(\\cdot|{\\bm{x}})} ),\n$$ where $\\mathcal{B} = \\{{\\bm{x}}^{(i)}\\}_{i=1}^{n_B}$\ndenotes a batch of inputs and $D$ denotes some distance measure." (292:1-311:65, 16016-17444)
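As an illustration of the batch objective above, the following sketch averages a per-input divergence over a batch $\mathcal{B}$. It assumes that the teacher and student next-token distributions for each input are already available as probability lists. It uses forward KL as the example choice of $D$ and computes only the loss value, with no gradients. All names are illustrative.

```python
import math
from typing import Callable, List, Sequence, Tuple

Dist = List[float]
Divergence = Callable[[Dist, Dist], float]


def forward_kl(p: Dist, q: Dist) -> float:
    """D_KL(p || q) for categorical distributions.

    Terms with p_i = 0 contribute 0; q_i is assumed positive wherever p_i > 0,
    which holds for softmax outputs.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def batch_distill_loss(batch: Sequence[Tuple[Dist, Dist]], D: Divergence = forward_kl) -> float:
    """(1 / n_B) * sum over the batch of D(p(.|x) || q_theta(.|x)).

    Each batch element is (teacher distribution p(.|x), student distribution q_theta(.|x)).
    """
    if not batch:
        raise ValueError("batch must be non-empty")
    return sum(D(p, q) for p, q in batch) / len(batch)
```

Passing `D` as a parameter mirrors the text's point that the same batch objective can be instantiated with different distance measures.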
├─28 paragraph[10] (313:1-338:44, 17446-19459)
│ ├─0 strong[1] (313:1-313:22, 17446-17467)
│ │ └─0 text "Distance measure." (313:3-313:20, 17448-17465)
│ ├─1 text " For auto-regressive models, the\nprediction distribution at each token is categorical. The predicted\nlogits can optionally be scaled by a tunable temperature $\\tau$ before\nthe softmax. We then instantiate $D$ with the popular forward KL and\nreverse KL (RKL) divergences, as well as their mixture, the generalized\nJensen-Shannon divergence (JSD): $$\\small\n\\begin{aligned}\n&\\ell_{KL}({\\bm{x}}, {\\bm{\\theta}}) = D_{\\mathrm{KL}}( {p(\\cdot|{\\bm{x}})}\\Vert {q_{\\bm{\\theta}}(\\cdot|{\\bm{x}})}), \\\\\n&\\ell_{RKL}({\\bm{x}}, {\\bm{\\theta}}) = D_{\\mathrm{KL}}({q_{\\bm{\\theta}}(\\cdot|{\\bm{x}})} \\Vert {p(\\cdot|{\\bm{x}})}), \\\\\n&\\ell_{{JSD}[\\beta]} ({\\bm{x}}, {\\bm{\\theta}}) = \\beta D_{\\mathrm{KL}}\\left({p(\\cdot|{\\bm{x}})} \\big\\Vert {p}^\\beta_{\\bm{\\theta}}(\\cdot|{\\bm{x}})\\right)+ (1-\\beta) D_{\\mathrm{KL}}\\left({q_{\\bm{\\theta}}(\\cdot|{\\bm{x}})} \\big\\Vert {p}^\\beta_{\\bm{\\theta}}(\\cdot|{\\bm{x}})\\right),\n\\end{aligned}$$ where\n${p}^\\beta_{\\bm{\\theta}}(\\cdot|{\\bm{x}}) \\triangleq \\beta{p(\\cdot|{\\bm{x}})} + (1-\\beta){q_{\\bm{\\theta}}(\\cdot|{\\bm{x}})}$.\nThese objectives differ from the label-based fine-tuning objectives\nconventionally used in speculative decoding, as highlighted in prior work. As\nshown in Section " (313:22-327:18, 17467-18673)
│ ├─2 html "<a href=\"#sec:offline-eval\" data-reference-type=\"ref\"\ndata-reference=\"sec:offline-eval\">" (327:18-328:35, 18673-18761)
│ ├─3 text "5.1" (328:35-328:38, 18761-18764)
│ ├─4 html "</a>" (328:38-328:42, 18764-18768)
│ ├─5 text ", objectives based on the KL\ndivergence prove to be more effective. This is because distributions\nconvey richer information than mere labels, which enhances their\ncapability to guide the student model. Additionally, these objectives\nimprove convergence rates and calibration. The reverse KL is\nnoted for its mode-seeking behavior, which offers distinct advantages.\nIn line with previous research, we empirically find that the optimal\ndistance measure can vary depending on the\ntasks and the relative capacities of the teacher and student models (see\n§" (328:42-337:2, 18768-19362)
│ ├─6 html "<a href=\"#sec:offline-eval\" data-reference-type=\"ref\"\ndata-reference=\"sec:offline-eval\">" (337:2-338:35, 19362-19450)
│ ├─7 text "5.1" (338:35-338:38, 19450-19453)
│ ├─8 html "</a>" (338:38-338:42, 19453-19457)
│ └─9 text ")." (338:42-338:44, 19457-19459)
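The three instantiations of $D$ can be written directly for categorical distributions. This sketch assumes that distributions are plain probability lists obtained from a temperature-scaled softmax. The helper names are hypothetical.

```python
import math
from typing import List, Sequence

Dist = List[float]


def softmax(logits: Sequence[float], tau: float = 1.0) -> Dist:
    """Temperature-scaled softmax; a larger tau flattens the distribution."""
    scaled = [z / tau for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(a: Dist, b: Dist) -> float:
    """D_KL(a || b); assumes b_i > 0 wherever a_i > 0."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


def loss_kl(p: Dist, q: Dist) -> float:
    """Forward KL: D_KL(p || q_theta)."""
    return kl(p, q)


def loss_rkl(p: Dist, q: Dist) -> float:
    """Reverse KL: D_KL(q_theta || p)."""
    return kl(q, p)


def loss_jsd(p: Dist, q: Dist, beta: float) -> float:
    """JSD[beta] = beta * KL(p || m) + (1 - beta) * KL(q || m), m = beta*p + (1-beta)*q."""
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, m) + (1 - beta) * kl(q, m)
```

With this definition, $\ell_{JSD[\beta]}$ vanishes at $\beta = 0$ and at $\beta = 1$, because the mixture then coincides with one of its components. At $\beta = 0.5$ it reduces to the standard, symmetric Jensen-Shannon divergence.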
├─29 paragraph[2] (340:1-346:55, 19461-19925)
│ ├─0 strong[1] (340:1-340:38, 19461-19498)
│ │ └─0 text "Sampling and gradient estimation." (340:3-340:36, 19463-19496)
│ └─1 text " Estimating the above objectives\ninvolves expectations over $q_{\\bm{\\theta}}(\\cdot|{\\bm{x}})$ or\n$p(\\cdot|{\\bm{x}})$, which should be expanded recursively. Once the\nrecursion depth exceeds $1$, $D_{\\mathrm{KL}}$ can no longer be computed\nanalytically, and we must rely on Monte Carlo approximation. When sampling\nfrom $q_{\\bm{\\theta}}(\\cdot|{\\bm{x}})$, we should differentiate through\nthe sampling process for unbiased gradient estimation." (340:38-346:55, 19498-19925)
├─30 paragraph[1] (348:1-350:68, 19927-20127)
│ └─0 text "However, this leads to policy-gradient-style estimators, which must rely\non elaborate techniques such as reward hacking and single-step\nregularization to reduce gradient variance and stabilize training." (348:1-350:68, 19927-20127)
├─31 paragraph[1] (352:1-362:28, 20129-20807)
│ └─0 text "In comparison, a more straightforward approach is to omit the\ndifferentiation through the sampling process, where the sample\n${\\bm{y}}$ is directly plugged into the objective: $$\\label{eq:offline}\n\\small\n\\ell({\\bm{x}}, {\\bm{\\theta}}) \\approx\n\\sum_{j =1}^{|{\\bm{y}}|+1} D({p(y|{\\bm{x}}, {\\bm{y}}_{<j})} \\Vert {q_{\\bm{\\theta}}(y|{\\bm{x}}, {\\bm{y}}_{<j})} ).$$\nThis way, various distance measures can be readily applied. Moreover, the\nsampling becomes disentangled from the distance measure, i.e., we can sample\n${\\bm{y}}$ from an arbitrary mixture of ${p}(\\cdot|{\\bm{x}})$ and\n$q_{\\bm{\\theta}}(\\cdot|{\\bm{x}})$ but use KL, RKL, or JSD to estimate the\ndistribution misalignment." (352:1-362:28, 20129-20807)
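The token-level approximation can be sketched as follows. The sketch assumes that each model is exposed as a function from a context (prompt plus prefix) to a next-token distribution. Because ${\bm{y}}$ is fixed before the loss is evaluated, the sum never depends on how ${\bm{y}}$ was drawn, which is what omitting differentiation through sampling amounts to. The function names are hypothetical. A real implementation would score all prefixes in one forward pass per model and backpropagate only through $q_{\bm{\theta}}$.

```python
import math
from typing import Callable, List, Sequence

Dist = List[float]
Model = Callable[[Sequence[int]], Dist]  # context -> next-token distribution


def kl(a: Dist, b: Dist) -> float:
    """D_KL(a || b); assumes b_i > 0 wherever a_i > 0."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


def sequence_distill_loss(x: Sequence[int], y: Sequence[int], p: Model, q: Model,
                          D: Callable[[Dist, Dist], float] = kl) -> float:
    """sum_{j=1}^{|y|+1} D(p(.|x, y_<j) || q(.|x, y_<j)) for a fixed, already-sampled y.

    y is treated as a constant: nothing here depends on how y was sampled.
    """
    total = 0.0
    for j in range(len(y) + 1):  # j = 0..|y| covers the prefixes y_<1 .. y_<|y|+1
        ctx = list(x) + list(y[:j])
        total += D(p(ctx), q(ctx))
    return total
```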
├─32 paragraph[1] (364:1-369:111, 20809-21213)
│ └─0 text "Intuitively, samples from the teacher model are usually coherent,\nwhich may make them hard for the small student model to fit, while\nsamples from the student model may be less structured or even\nmeaningless. A workaround is to trade off between them via\nmixed sampling, i.e.,\n$y_j \\sim \\beta{p(\\cdot|{\\bm{x}}, {\\bm{y}}_{<j})} + (1-\\beta) q_{\\bm{\\theta}}(\\cdot|{\\bm{x}}, {\\bm{y}}_{<j})$." (364:1-369:111, 20809-21213)
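Mixed sampling can be sketched as follows, again treating each model as a function from context to next-token distribution (a hypothetical interface). The sketch picks the teacher with probability $\beta$ and otherwise the student, then samples from the chosen model. This draws each $y_j$ exactly from $\beta p + (1-\beta) q_{\bm{\theta}}$, so only one model needs to be queried per step in this sketch.

```python
import random
from typing import Callable, List, Sequence

Dist = List[float]
Model = Callable[[Sequence[int]], Dist]  # context -> next-token distribution


def mixed_sample(x: Sequence[int], p: Model, q: Model, beta: float,
                 max_len: int, eos: int, rng: random.Random) -> List[int]:
    """Sample y token by token from beta * p(.|x, y_<j) + (1 - beta) * q(.|x, y_<j).

    Choosing the teacher with probability beta (else the student) and then sampling
    from the chosen model draws exactly from the mixture distribution.
    """
    y: List[int] = []
    for _ in range(max_len):
        ctx = list(x) + y
        dist = p(ctx) if rng.random() < beta else q(ctx)
        tok = rng.choices(range(len(dist)), weights=dist, k=1)[0]
        y.append(tok)
        if tok == eos:
            break
    return y
```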
├─33 heading[1] (371:1-371:33, 21215-21247)
│ │ depth: 2
│ └─0 text "Online Knowledge Distillation" (371:4-371:33, 21218-21247)
├─34 paragraph[1] (373:1-378:47, 21249-21636)
│ └─0 text "This section extends knowledge distillation for speculative decoding\nto online environments. The approach improves the draft model using\nthe results of speculative decoding, thus dynamically adapting to the\nquery distribution and increasing the token acceptance rate. We also\ndiscuss the trade-offs of integrating our approach into LLM serving systems." (373:1-378:47, 21249-21636)
├─35 heading[1] (380:1-380:14, 21638-21651)
│ │ depth: 3
│ └─0 text "Algorithm" (380:5-380:14, 21642-21651)
├─36 html "<div class=\"algorithm\">" (382:1-382:24, 21653-21676)
├─37 html "<div class=\"algorithmic\">" (384:1-384:26, 21678-21703)
├─38 paragraph[1] (386:1-390:17, 21705-21965)
│ └─0 text "Target LLM $p(\\cdot|{\\bm{x}})$, draft LLM\n$q_{\\bm{\\theta}}(\\cdot|{\\bm{x}})$, warmup dataset $\\mathcal{D}$, online\ndata stream $\\mathcal{S}$, guess number $k$, temporary buffer\n$\\mathcal{R}$, replay buffer $\\mathcal{Q}$, update interval for the\ndraft model $I$." (386:1-390:17, 21705-21965)
├─39 html "</div>" (392:1-392:7, 21967-21973)
├─40 html "</div>" (394:1-394:7, 21975-21981)
├─41 paragraph[9] (396:1-411:26, 21983-22990)
│ ├─0 text "We depict our online speculative decoding algorithm (OSD) in\n" (396:1-397:1, 21983-22044)
│ ├─1 html "<a href=\"#algo:1\" data-reference-type=\"ref\"\ndata-reference=\"algo:1\">" (397:1-398:25, 22044-22112)
│ ├─2 text "Algorithm 1" (398:25-398:33, 22112-22120)
│ ├─3 html "</a>" (398:33-398:37, 22120-22124)
│ ├─4 text ". OSD begins by training the draft\nmodel using the warmup dataset (Line 2). The serving system then\ncontinuously handles incoming requests (as described in Lines 6 to 23).\nFor each request, it uses standard speculative decoding (Lines 10-11) to\ngenerate responses until the ⟨EOS⟩ token. Concurrently, OSD tracks the\ntoken index ($\\mathit{error\\_index}$) and target logits where the draft model\nproposes the wrong tokens (Line 15). Using this tracked information, OSD\nupdates the draft model every $I$ iterations, where $I$ is a\ndynamically adjustable parameter. OSD updates the draft model with\none of the loss functions (Line 20) described in\nSection " (398:37-408:9, 22124-22770)
│ ├─5 html "<a href=\"#sec:knowledge-distill\" data-reference-type=\"ref\"\ndata-reference=\"sec:knowledge-distill\">" (408:9-409:40, 22770-22868)
│ ├─6 text "4.1" (409:40-409:43, 22868-22871)
│ ├─7 html "</a>" (409:43-409:47, 22871-22875)
│ └─8 text ". The choice of loss\nfunction depends on the specific (draft, target) model pairs and the\ncorresponding input data." (409:47-411:26, 22875-22990)
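The following toy simulation mirrors the control flow described above. For each request it runs speculative decoding, records the target's token wherever the draft was wrong, and refreshes the draft every $I$ requests before releasing the buffer. It is a sketch under strong simplifying assumptions, not the authors' implementation:

- Verification is greedy.
- One iteration is taken to be one request.
- The draft is a bigram lookup table.
- "Updating" the draft overwrites table entries instead of taking a distillation step with KL, RKL, or JSD.
- Prompts are assumed to be non-empty.

All names are hypothetical.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def osd_simulate(
    requests: Sequence[Sequence[int]],
    target_next: Callable[[List[int]], int],
    k: int,
    interval: int,
    max_new: int,
) -> List[float]:
    """Toy OSD serving loop that returns each request's acceptance rate.

    Stand-ins (hypothetical, not the paper's models): the target is a greedy
    next-token function, the draft is a bigram table (last token -> predicted
    token, default 0), and an "update" copies buffered target corrections into
    the table instead of running a distillation step.
    """
    draft: Dict[int, int] = {}
    replay: List[Tuple[int, int]] = []  # (last context token, target's token)
    rates: List[float] = []
    for step, prompt in enumerate(requests, start=1):
        ctx = list(prompt)
        proposed = accepted = 0
        while len(ctx) - len(prompt) < max_new:
            # The draft proposes k tokens greedily from its own running context.
            guess: List[int] = []
            last = ctx[-1]
            for _ in range(k):
                last = draft.get(last, 0)
                guess.append(last)
            # The target verifies left to right and stops at the first mismatch.
            for tok in guess:
                want = target_next(ctx)
                proposed += 1
                if tok != want:
                    replay.append((ctx[-1], want))  # record the draft's mistake
                    ctx.append(want)
                    break
                accepted += 1
                ctx.append(tok)
            else:
                ctx.append(target_next(ctx))  # all accepted: one bonus target token
        rates.append(accepted / proposed)
        # Every `interval` requests, "update" the draft and release the buffer.
        if step % interval == 0:
            for last_tok, want in replay:
                draft[last_tok] = want
            replay.clear()
    return rates
```

Consider a target that always emits the successor of the last token (modulo 10), with four identical requests, $k = 3$, and an interval of 2. The draft accepts nothing on the first two requests and every proposal after the first update. This is an exaggerated version of the rising acceptance rate OSD aims for.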
├─42 paragraph[2] (413:1-430:35, 22992-24221)
│ ├─0 strong[1] (413:1-413:16, 22992-23007)
│ │ └─0 text "Discussion." (413:3-413:14, 22994-23005)
│ └─1 text " OSD uses a replay buffer, $\\mathcal{Q}$, to capture\nall pertinent information for updating the draft model. Various eviction\npolicies can be employed to keep $\\mathcal{Q}$ compact.\nFor example, one could opt to retain only the most informative pairs or\nthe most recent entries. Similarly, users can choose to keep data\nin $\\mathcal{Q}$ after it has been used for an update, so that it is reused\nacross multiple updates. Determining the optimal eviction/retention strategy is a subject\nfor future exploration. In the current study, we do not evict\nany pairs and release $\\mathcal{Q}$ after each model update.\nFurthermore, $I$ is a dynamic parameter. Depending on the system load\nand the rate at which the query distribution changes, users can adjust\n$I$ accordingly. For example, we can perform a gradient update\nopportunistically only when service traffic is not spiking (i.e.,\nwhen spare FLOPs are available). Overall, OSD continuously improves the draft\nmodel’s approximation (indicated by an increased token acceptance rate\n$\\alpha$) by learning from the target model during the serving phase. We\nnext demonstrate how the improved acceptance rate directly contributes\nto a reduction in request latency." (413:16-430:35, 23007-24221)
├─43 heading[1] (432:1-432:29, 24223-24251)
│ │ depth: 3
│ └─0 text "Latency & FLOPs Analysis" (432:5-432:29, 24227-24251)
├─44 paragraph[10] (434:1-447:67, 24253-25114)
│ ├─0 strong[1] (434:1-434:13, 24253-24265)
│ │ └─0 text "Latency." (434:3-434:11, 24255-24263)
│ ├─1 text " As detailed in\nAppendix " (434:13-435:10, 24265-24290)
│ ├─2 html "<a href=\"#appendix:latency-analysis\" data-reference-type=\"ref\"\ndata-reference=\"appendix:latency-analysis\">" (435:10-436:44, 24290-24396)
│ ├─3 text "7.2" (436:44-436:47, 24396-24399)
│ ├─4 html "</a>" (436:47-436:51, 24399-24403)
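│ ├─5 text ", compared with\nstandard speculative decoding, the expected speedup for online\nspeculative decoding is\n$\\frac{1+\\alpha_2+\\alpha_2^2+\\dots+\\alpha_2^{k}}{1+\\alpha_1+\\alpha_1^2+\\dots+\\alpha_1^k}$,\nwhere $\\alpha_1$ and $\\alpha_2$ are the token acceptance rates of standard\nand online speculative decoding, respectively, and $k$ is the number of\nproposed tokens. Based on the data from our experiment (refer to\nTable " (436:51-441:7, 24403-24647)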
│ ├─6 html "<a href=\"#tab:apha\" data-reference-type=\"ref\"\ndata-reference=\"tab:apha\">" (441:7-442:27, 24647-24719)
│ ├─7 text "1" (442:27-442:28, 24719-24720)
│ ├─8 html "</a>" (442:28-442:32, 24720-24724)
│ └─9 text "), when compared to standard speculative\ndecoding, we expect OSD to improve the speedup for Vicuna-7B (LLaMA-160M as\nthe draft model) by factors of $2.42\\times$, $1.43\\times$, $1.64\\times$,\nand $1.22\\times$ on Spider, Gsm8k, Code-search-Python, and Alpaca-finance,\nrespectively. Similarly, for Flan-T5-XL 3B (T5-small 80M as the\ndraft model), the corresponding speedup improvements are $3.06\\times$, $1.76\\times$,\n$2.72\\times$, and $1.55\\times$." (442:32-447:67, 24724-25114)
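As a quick check, the factors above can be recomputed from the speedup formula. The sketch assumes that $\alpha_1$ is the acceptance rate of the original draft model and $\alpha_2$ is that of the draft distilled with teacher sampling (the Original and TF columns of Table 1). It uses $k = 5$ proposed tokens, as in the experiments. With these inputs, the computed ratios agree with the quoted factors to within 0.01. The function names are illustrative.

```python
from typing import Dict, Tuple


def expected_tokens(alpha: float, k: int) -> float:
    """1 + alpha + alpha^2 + ... + alpha^k."""
    return sum(alpha ** i for i in range(k + 1))


def osd_speedup(alpha_1: float, alpha_2: float, k: int = 5) -> float:
    """Expected speedup of OSD (acceptance rate alpha_2) over standard
    speculative decoding (acceptance rate alpha_1) with k proposed tokens."""
    return expected_tokens(alpha_2, k) / expected_tokens(alpha_1, k)


# (Original, TF) acceptance rates from the table of token acceptance rates.
RATES: Dict[str, Tuple[float, float]] = {
    "Vicuna-7B / Spider": (0.28, 0.76),
    "Vicuna-7B / Gsm8k": (0.58, 0.75),
    "Vicuna-7B / Code-search-Python": (0.38, 0.65),
    "Vicuna-7B / Alpaca-finance": (0.57, 0.67),
    "FLAN-T5-XL / Spider": (0.13, 0.78),
    "FLAN-T5-XL / Gsm8k": (0.29, 0.62),
    "FLAN-T5-XL / Code-search-Python": (0.28, 0.81),
    "FLAN-T5-XL / Alpaca-finance": (0.39, 0.63),
}

if __name__ == "__main__":
    for name, (a1, a2) in RATES.items():
        print(f"{name}: {osd_speedup(a1, a2):.3f}x")
```

The largest gains occur where the original draft model is weakest (e.g., FLAN-T5-XL on Spider, where $\alpha_1 = 0.13$). In that case, the denominator stays close to 1.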
├─45 paragraph[10] (449:1-464:48, 25116-26135)
│ ├─0 strong[1] (449:1-449:11, 25116-25126)
│ │ └─0 text "FLOPs." (449:3-449:9, 25118-25124)
│ ├─1 text " (1) The FLOPs required to update the draft model are\nsignificantly fewer than those needed for inference on a large model. As\nelaborated in\nAppendix " (449:11-452:10, 25126-25276)
│ ├─2 html "<a href=\"#appendix:flops\" data-reference-type=\"ref\"\ndata-reference=\"appendix:flops\">" (452:10-453:33, 25276-25360)
│ ├─3 text "7.3" (453:33-453:36, 25360-25363)
│ ├─4 html "</a>" (453:36-453:40, 25363-25367)
│ ├─5 text ", for the two evaluated model\npairs, the FLOPs ratio between the target model and the draft model is\n18.75 for the pair (LLaMA-160M, Vicuna-7B), and 12.6 for the pair\n(T5-small 80M, Flan-T5-XL 3B). (2) In practical systems, the FLOPs\nrequired for inference are significantly below the machine’s capacity.\nAppendix " (453:40-458:14, 25367-25684)
│ ├─6 html "<a href=\"#appendix:flops\" data-reference-type=\"ref\"\ndata-reference=\"appendix:flops\">" (458:14-459:33, 25684-25768)
│ ├─7 text "7.3" (459:33-459:36, 25768-25771)
│ ├─8 html "</a>" (459:36-459:40, 25771-25775)
│ └─9 text " provides an analysis of Arena\nchatbot traces where the cluster’s computational utilization is under 1\npercent. Given the above two observations, it is evident that the\nFLOPs spent on running and updating the draft model are relatively\ninsignificant compared with the FLOPs consumed while operating\nthe target model and the cluster’s total FLOPs." (459:40-464:48, 25775-26135)
├─46 heading[1] (466:1-466:14, 26137-26150)
│ │ depth: 1
│ └─0 text "Experiments" (466:3-466:14, 26139-26150)
├─47 paragraph[1] (468:1-483:57, 26152-27258)
│ └─0 text "To assess the efficacy of our method, we first evaluate its ability\nto improve the token acceptance rate ($\\alpha$) in an offline\ncontext. This provides us with a theoretical upper bound on the\nperformance improvements achievable when the query distribution remains\nconstant. Subsequently, we examine the approach’s impact in an online\nenvironment, finding that the acceptance rate improves even with a\nmoderate amount of data while reaching levels comparable to\nthose in the offline scenario. Throughout our experiments, we employ two\ntarget models ($M_p$): Vicuna-7B and FLAN-T5-XL (3B). Specifically, for\nVicuna-7B, we use LLaMA-160M as the draft model ($M_q$). For\nFLAN-T5-XL, we use T5-Small as the draft model. We evaluate performance\nacross four diverse datasets: Text-to-SQL (Spider), grade school\nmath (Gsm8k), Python code generation (Code-search-Python), and\nfinancial question answering (Alpaca-finance). In all experiments, we\nset the number of proposed tokens to 5 for speculative decoding. For all\nonline experiments, we fix the update interval $I$ at 8." (468:1-483:57, 26152-27258)
├─48 heading[1] (485:1-485:22, 27260-27281)
│ │ depth: 2
│ └─0 text "Offline Evaluation" (485:4-485:22, 27263-27281)
├─49 paragraph[9] (487:1-503:64, 27283-28358)
│ ├─0 text "In this section, we assess the efficacy of employing knowledge\ndistillation to train a small model specifically for speculation in an\noffline environment. In such a setting, the speculative $M_q$ model has\nunrestricted access to the dataset, and the query distribution remains\nstable. To emulate these offline conditions, we distill $M_q$ on\nthe training dataset for two epochs and subsequently evaluate its\nperformance by measuring the average token acceptance rate ($\\alpha$) on\nthe test set. As detailed in\nSection " (487:1-495:9, 27283-27808)
│ ├─1 html "<a href=\"#sec:knowledge-distill\" data-reference-type=\"ref\"\ndata-reference=\"sec:knowledge-distill\">" (495:9-496:40, 27808-27906)
│ ├─2 text "4.1" (496:40-496:43, 27906-27909)
│ ├─3 html "</a>" (496:43-496:47, 27909-27913)
│ ├─4 text ", we evaluate various\nsampling methods, namely teacher sampling, student sampling, and mixed\ntoken-level sampling.\nTable " (496:47-499:7, 27913-28033)
│ ├─5 html "<a href=\"#tab:apha\" data-reference-type=\"ref\"\ndata-reference=\"tab:apha\">" (499:7-500:27, 28033-28105)
│ ├─6 text "1" (500:27-500:28, 28105-28106)
│ ├─7 html "</a>" (500:28-500:32, 28106-28110)
│ └─8 text " displays the token acceptance rate of\nthe draft model for each method, using forward KL as the distance metric\non the test dataset. For comparison, we also provide the acceptance rate\nfor teacher-generated label fine-tuning and the original model." (500:32-503:64, 28110-28358)
├─50 paragraph[1] (505:1-516:27, 28360-29145)
│ └─0 text "For both the Vicuna-7B and FLAN-T5-XL models, teacher sampling\noutperforms the other sampling methods, achieving the highest acceptance rate.\nFurthermore, knowledge distillation proves effective in improving\nthe draft model’s approximation, resulting in a high token acceptance\nrate. Intriguingly, we also find that fine-tuning with teacher-generated\nlabels yields impressive performance on the Vicuna-7B model. Lastly, we\nexperimented with other distance measures, namely reverse KL and\nJSD. However, these measures either matched or\nunderperformed forward KL. These empirical findings\nunderscore that the optimal distance measure or sampling method\nvaries with the task and model; we leave finding the best combination\nto future work." (505:1-516:27, 28360-29145)
├─51 html "<div class=\"small\">" (518:1-518:20, 29147-29166)
├─52 html "<div class=\"center\">" (520:1-520:21, 29168-29188)
├─53 html "<div id=\"tab:apha\">" (522:1-522:20, 29190-29209)
├─54 paragraph[33] (524:1-533:94, 29211-30150)
│ ├─0 text "| " (524:1-524:3, 29211-29213)
│ ├─1 strong[1] (524:3-524:12, 29213-29222)
│ │ └─0 text "Model" (524:5-524:10, 29215-29220)
│ ├─2 text " | " (524:12-524:16, 29222-29226)
│ ├─3 strong[1] (524:16-524:24, 29226-29234)
│ │ └─0 text "Task" (524:18-524:22, 29228-29232)
│ ├─4 text " | " (524:24-524:37, 29234-29247)
│ ├─5 strong[1] (524:37-524:49, 29247-29259)
│ │ └─0 text "Original" (524:39-524:47, 29249-29257)
│ ├─6 text " | " (524:49-524:52, 29259-29262)
│ ├─7 strong[1] (524:52-524:58, 29262-29268)
│ │ └─0 text "FT" (524:54-524:56, 29264-29266)
│ ├─8 text " | " (524:58-524:63, 29268-29273)
│ ├─9 strong[1] (524:63-524:69, 29273-29279)
│ │ └─0 text "TF" (524:65-524:67, 29275-29277)
│ ├─10 text " | " (524:69-524:74, 29279-29284)
│ ├─11 strong[1] (524:74-524:80, 29284-29290)
│ │ └─0 text "SF" (524:76-524:78, 29286-29288)
│ ├─12 text " | " (524:80-524:84, 29290-29294)
│ ├─13 strong[1] (524:84-524:92, 29294-29302)
│ │ └─0 text "MixF" (524:86-524:90, 29296-29300)
│ ├─14 text " |\n|:-----------|:-------------------|:-------------|:---------|:---------|:--------|:---------|\n| Vicuna-7B | Spider | 0.28 | 0.74 | " (524:92-526:63, 29302-29461)
│ ├─15 strong[1] (526:63-526:71, 29461-29469)
│ │ └─0 text "0.76" (526:65-526:69, 29463-29467)
│ ├─16 text " | 0.62 | 0.70 |\n| | Gsm8k | 0.58 | 0.74 | " (526:71-527:63, 29469-29555)
│ ├─17 strong[1] (527:63-527:71, 29555-29563)
│ │ └─0 text "0.75" (527:65-527:69, 29557-29561)
│ ├─18 text " | 0.67 | 0.73 |\n| | Code-search-Python | 0.38 | " (527:71-528:52, 29563-29638)
│ ├─19 strong[1] (528:52-528:60, 29638-29646)
│ │ └─0 text "0.65" (528:54-528:58, 29640-29644)
│ ├─20 text " | " (528:60-528:63, 29646-29649)
│ ├─21 strong[1] (528:63-528:71, 29649-29657)
│ │ └─0 text "0.65" (528:65-528:69, 29651-29655)
│ ├─22 text " | 0.51 | 0.61 |\n| | Alpaca-finance | 0.57 | " (528:71-529:52, 29657-29732)
│ ├─23 strong[1] (529:52-529:60, 29732-29740)
│ │ └─0 text "0.68" (529:54-529:58, 29734-29738)
│ ├─24 text " | 0.67 | 0.63 | 0.65 |\n| FLAN T5-XL | Spider | 0.13 | 0.33 | " (529:60-530:63, 29740-29837)
│ ├─25 strong[1] (530:63-530:71, 29837-29845)
│ │ └─0 text "0.78" (530:65-530:69, 29839-29843)
│ ├─26 text " | 0.67 | 0.70 |\n| | Gsm8k | 0.29 | 0.50 | " (530:71-531:63, 29845-29931)
│ ├─27 strong[1] (531:63-531:71, 29931-29939)
│ │ └─0 text "0.62" (531:65-531:69, 29933-29937)
│ ├─28 text " | 0.51 | 0.55 |\n| | Code-search-Python | 0.28 | 0.44 | " (531:71-532:63, 29939-30025)
│ ├─29 strong[1] (532:63-532:71, 30025-30033)
│ │ └─0 text "0.81" (532:65-532:69, 30027-30031)
│ ├─30 text " | 0.67 | 0.78 |\n| | Alpaca-finance | 0.39 | 0.56 | " (532:71-533:63, 30033-30119)
│ ├─31 strong[1] (533:63-533:71, 30119-30127)
│ │ └─0 text "0.63" (533:65-533:69, 30121-30125)
│ └─32 text " | 0.59 | 0.60 |" (533:71-533:94, 30127-30150)
├─55 paragraph[5] (535:1-537:50, 30152-30345)
│ ├─0 text "Token acceptance rates ($\\alpha$) after two epochs. " (535:1-535:53, 30152-30204)
│ ├─1 strong[1] (535:53-535:59, 30204-30210)
│ │ └─0 text "FT" (535:55-535:57, 30206-30208)
│ ├─2 text ": Finetuning\non teacher-generated labels. " (535:59-536:30, 30210-30252)
│ ├─3 strong[1] (536:30-536:46, 30252-30268)
│ │ └─0 text "TF, SF, MixF" (536:32-536:44, 30254-30266)
│ └─4 text ": Teacher, student, and mixed\ntoken-level sampling, respectively, all with forward KL." (536:46-537:50, 30268-30345)
├─56 html "</div>" (539:1-539:7, 30347-30353)
├─57 html "</div>" (541:1-541:7, 30355-30361)
├─58 html "</div>" (543:1-543:7, 30363-30369)
├─59 heading[1] (545:1-545:21, 30371-30391)
│ │ depth: 2
│ └─0 text "Online Evaluation" (545:4-545:21, 30374-30391)
├─60 paragraph[2] (547:1-553:59, 30393-30875)
│ ├─0 strong[1] (547:1-547:21, 30393-30413)
│ │ └─0 text "Online Learning." (547:3-547:19, 30395-30411)
│ └─1 text " First, we evaluate the effectiveness of our online\nalgorithm by addressing two key questions: (1) Does the online algorithm\nincrease the token acceptance rate? And is this enhancement comparable\nto the rates achieved in offline settings, which serve as an upper bound\ngiven their full access to data? (2) How quickly does the online\nalgorithm increase the token acceptance rate, thereby indicating that\nthe compact model has grasped the underlying distribution?" (547:21-553:59, 30413-30875)
├─61 paragraph[5] (555:1-568:71, 30877-31791)
│ ├─0 text "In our approach, we replicate the online serving process by iterating\nthrough the datasets, extracting prompts, and streaming generation\nrequests. The system utilizes speculative decoding for each of these\nrequests. Throughout this serving phase, we continually refine the\nspeculative models, as detailed in\nAlgorithm " (555:1-560:11, 30877-31195)
│ ├─1 html "<a href=\"#algo:1\" data-reference-type=\"ref\"\ndata-reference=\"algo:1\">" (560:11-561:25, 31195-31263)
│ ├─2 text "1" (561:25-561:33, 31263-31271)
│ ├─3 html "</a>" (561:33-561:37, 31271-31275)
│ └─4 text ". For our baseline, we envision a\nscenario where the serving system can collect data\noffline in order to distill an initial draft model. This model is\nsubsequently deployed online to serve future requests. This process\nis simulated by using 10% of the dataset to distill the draft model,\nwhich remains static during online serving. As the evaluation metric, we\ncalculate token acceptance rates averaged over the most recent 50\nrequests. This reflects $M_q$’s efficacy on the most recent data." (561:37-568:71, 31275-31791)
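A small helper for the windowed metric might look like the following. This sketch averages per-request acceptance rates over the last 50 requests. Pooling accepted and proposed token counts across the window would be an equally plausible reading of the metric. The class name is hypothetical.

```python
from collections import deque
from typing import Deque


class RecentAcceptance:
    """Acceptance rate averaged over the most recent `window` requests."""

    def __init__(self, window: int = 50) -> None:
        # A bounded deque drops the oldest request's rate once the window is full.
        self.rates: Deque[float] = deque(maxlen=window)

    def add(self, accepted: int, proposed: int) -> float:
        """Record one request's token counts and return the current windowed average."""
        self.rates.append(accepted / proposed)
        return sum(self.rates) / len(self.rates)
```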
├─62 html "<figure id=\"fig:alphas\">\n<p><embed src=\"figures/legend_figure1.pdf\" /> <embed\nsrc=\"figures/spider_vicuna.pdf\" /> <embed\nsrc=\"figures/gsm8k_vicuna.pdf\" /> <embed\nsrc=\"figures/python_vicuna.pdf\" /> <embed\nsrc=\"figures/finance_vicuna.pdf\" /> <embed\nsrc=\"figures/spider_flant5xl_to_t5small.pdf\" /> <embed\nsrc=\"figures/gsm8k_flant5xl_to_t5small.pdf\" /> <embed\nsrc=\"figures/python_flant5xl_to_t5small.pdf\" /> <embed\nsrc=\"figures/finance_flant5xl_to_t5small.pdf\" /></p>\n<figcaption>Online acceptance rate (<span\nclass=\"math inline\"><em>α</em></span>) across different datasets. The\nx-axis represents the number of records that OSD has processed. Alpha is\naveraged over the most recent 50 records.</figcaption>\n</figure>" (570:1-584:10, 31793-32505)
├─63 html "<figure id=\"fig:dis-shift\">\n<embed src=\"figures/sharp.pdf\" />\n<figcaption>Distribution Shift: Alpha is averaged over the most recent\n100 records.</figcaption>\n</figure>" (586:1-590:10, 32507-32675)
├─64 paragraph[1] (592:1-605:18, 32677-33600)
│ └─0 text "As depicted in Figure 3, for both Vicuna-7B and FLAN-T5, OSD initially\nyields a lower token acceptance rate than the offline distilled model.\nNevertheless, these acceptance rates rise\nswiftly as the draft model is exposed to more data. We also annotate the\ntoken acceptance rate from the offline setting to highlight the\npotential peak performance that the online serving system could reach.\nIn all instances, the online setting achieves comparable results. In\nsome scenarios, OSD even surpasses the offline test alphas. This\ndiscrepancy can be attributed to the fact that\noffline test alphas are assessed on the entire test dataset, whereas the\nonline alphas represent the moving average of the latest 50 requests.\nIt’s plausible that OSD performs optimally on specific data subsets,\nparticularly if those subsets are more narrowly distributed than the\ncomplete dataset." (592:1-605:18, 32677-33600)
├─65 paragraph[10] (607:1-618:15, 33602-34331)
│ ├─0 strong[1] (607:1-607:25, 33602-33626)
│ │ └─0 text "Distribution Shifts." (607:3-607:23, 33604-33624)
│ ├─1 text " We evaluate OSD’s ability to adapt to changes\nin data distribution. We detail the dataset preparation in\nAppendix " (607:25-609:10, 33626-33741)
│ ├─2 html "<a href=\"#appendix:distribution-shift\" data-reference-type=\"ref\"\ndata-reference=\"appendix:distribution-shift\">" (609:10-610:46, 33741-33851)
│ ├─3 text "[appendix:distribution-shift]" (610:46-610:75, 33851-33880)
│ ├─4 html "</a>" (610:75-610:79, 33880-33884)
│ ├─5 text ".\nAs illustrated in\nFigure " (610:79-612:8, 33884-33911)
│ ├─6 html "<a href=\"#fig:dis-shift\" data-reference-type=\"ref\"\ndata-reference=\"fig:dis-shift\">" (612:8-613:32, 33911-33993)
│ ├─7 text "4" (613:32-613:33, 33993-33994)
│ ├─8 html "</a>" (613:33-613:37, 33994-33998)
│ └─9 text ", OSD’s alpha value dips notably at\ndistribution boundaries, especially around 2K, 4K, and 6K records. This\nis anticipated since the draft model initially struggles when faced with\na new distribution. However, the alpha value rebounds quickly as OSD\nprocesses more data, highlighting its adaptability to shifting query\ndistributions." (613:37-618:15, 33998-34331)
├─66 paragraph[5] (620:1-632:10, 34333-35152)
│ ├─0 text "We also compared our results to those from a static setting. To ensure\nthe draft model wasn’t just memorizing data, we chose samples distinct\nfrom the online evaluation data. These samples correspond to 30%, 50%,\n70%, and 100% of each dataset’s online evaluation volume, at 0.6K, 1K,\n1.4K, and 2K quantities respectively. As depicted in\nFigure " (620:1-625:8, 34333-34677)
│ ├─1 html "<a href=\"#fig:dis-shift\" data-reference-type=\"ref\"\ndata-reference=\"fig:dis-shift\">" (625:8-626:32, 34677-34759)
│ ├─2 text "4" (626:32-626:33, 34759-34760)
│ ├─3 html "</a>" (626:33-626:37, 34760-34764)
│ └─4 text ", upon an initial shift in query\ndistribution, OSD’s performance aligns with or slightly trails the\ndistillation with 30% data. However, it quickly catches up, matching or\neven surpassing performances seen with 70% to 100% data access. This\nhighlights OSD’s ability to rival models fully exposed to the query\ndistribution, even without intimate knowledge of the underlying query\ndynamics." (626:37-632:10, 34764-35152)
├─67 paragraph[14] (634:1-651:63, 35154-36337)
│ ├─0 strong[1] (634:1-634:20, 35154-35173)
│ │ └─0 text "Real Workloads." (634:3-634:18, 35156-35171)
│ ├─1 text " We evaluate OSD on real LMSYS-chat conversations\n(Appendix " (634:20-635:12, 35173-35234)
│ ├─2 html "<a href=\"#appendix:arena\" data-reference-type=\"ref\"\ndata-reference=\"appendix:arena\">" (635:12-636:33, 35234-35318)
│ ├─3 text "7.6" (636:33-636:36, 35318-35321)
│ ├─4 html "</a>" (636:36-636:40, 35321-35325)
│ ├─5 text ") that span 4 months. First, we\ncategorize conversations by language and focus on\nconversations in the top five languages, excluding English. For each\nchosen language, we use an independent LLaMA-160M as the draft\nmodel. All draft models share the same Vicuna-7B as the target model.\nThe token acceptance rate, averaged over the latest 100 requests and shown\nin Figure " (636:40-642:11, 35325-35718)
│ ├─6 html "<a href=\"#fig:arena\" data-reference-type=\"ref\"\ndata-reference=\"fig:arena\">" (642:11-643:28, 35718-35792)
│ ├─7 text "5" (643:28-643:29, 35792-35793)
│ ├─8 html "</a>" (643:29-643:33, 35793-35797)
│ ├─9 text ", reveals that OSD improves the acceptance rate by\n0.1 to 0.2, even with under 2K data points. Notably, Japanese was the\neasiest language and Portuguese the hardest. We also clustered English\nconversations by topic using a fine-tuned distilled BERT model,\nfocusing on the top five topics. For topics with over 5K conversations, we\nsampled evenly to keep each within 5K.\nFigure " (643:33-649:8, 35797-36155)
│ ├─10 html "<a href=\"#fig:arena\" data-reference-type=\"ref\"\ndata-reference=\"fig:arena\">" (649:8-650:28, 36155-36229)
│ ├─11 text "5" (650:28-650:29, 36229-36230)
│ ├─12 html "</a>" (650:29-650:33, 36230-36234)
│ └─13 text " shows acceptance rates above 0.6 across\ntopics, with Social and Computer discussions peaking near 0.8." (650:33-651:63, 36234-36337)
├─68 html "<figure id=\"fig:arena\">\n<p><embed src=\"figures/arena_language.pdf\" /> <embed\nsrc=\"figures/arena_class.pdf\" /></p>\n<figcaption>Chatbot Arena Conversations clustered by language and\ntopic.</figcaption>\n</figure>" (653:1-658:10, 36339-36548)
├─69 html "<figure id=\"fig:freq-acc\">\n<p><embed src=\"figures/precision.pdf\" /> <embed\nsrc=\"figures/recall.pdf\" /></p>\n<figcaption>Precision and recall of high-frequency tokens. The x-axis\nshows the rank of the tokens based on their occurrence in the\ngenerated answers. For instance, token 1 appears most frequently in\nanswers. Precision = # of times token <span\nclass=\"math inline\"><em>i</em></span> is accepted by the target model /\n# of times token <span class=\"math inline\"><em>i</em></span> is proposed\nby the draft model. Recall = # of times token <span\nclass=\"math inline\"><em>i</em></span> is accepted by the target model /\n# of times token <span class=\"math inline\"><em>i</em></span> appears in\nthe final answer.</figcaption>\n</figure>" (660:1-673:10, 36550-37284)
├─70 heading[1] (675:1-675:24, 37286-37309)
│ │ depth: 2
│ └─0 text "Qualitative Analysis" (675:4-675:24, 37289-37309)
├─71 paragraph[1] (677:1-679:57, 37311-37506)
│ └─0 text "In this section, we conduct a comprehensive analysis to understand how\nour method enhances the token acceptance rate, and which tokens the\ndraft model acquires across varying query distributions." (677:1-679:57, 37311-37506)
├─72 paragraph[6] (681:1-689:34, 37508-38042)
│ ├─0 strong[1] (681:1-681:48, 37508-37555)
│ │ └─0 text "High-frequency tokens precision and recall." (681:3-681:46, 37510-37553)
│ ├─1 text " In our experiment using\nthe Spider dataset, Vicuna-7B is the target model and LLaMA-160M the\ndraft. We identify the top 100 tokens most frequently generated by the\ntarget model, which together account for 72.2% of all token occurrences; token\nfrequencies follow a power-law distribution.\nFigure " (681:48-686:8, 37555-37821)
│ ├─2 html "<a href=\"#fig:freq-acc\" data-reference-type=\"ref\"\ndata-reference=\"fig:freq-acc\">" (686:8-687:31, 37821-37901)
│ ├─3 text "6" (687:31-687:32, 37901-37902)
│ ├─4 html "</a>" (687:32-687:36, 37902-37906)
│ └─5 text " shows a marked improvement in both\nprecision and recall of these tokens on the test\ndataset after distillation, in an offline evaluation." (687:36-689:34, 37906-38042)
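
The per-token precision and recall defined in the figure caption can be computed from serving logs. This is a minimal sketch under assumed inputs: `proposals` is a hypothetical list of `(token, accepted)` pairs, one per token the draft model proposed, and `final_tokens` is a list of the tokens in the final answers. Neither format comes from the OSD code.

```python
from collections import Counter


def token_precision_recall(proposals, final_tokens):
    """Per-token precision and recall of draft-model proposals.

    proposals: iterable of (token, accepted) pairs, one per draft proposal;
        accepted is True if the target model accepted that token.
    final_tokens: iterable of tokens appearing in the final answers.
    Returns {token: (precision, recall)}.
    """
    proposed = Counter()
    accepted = Counter()
    for token, was_accepted in proposals:
        proposed[token] += 1
        if was_accepted:
            accepted[token] += 1
    appearances = Counter(final_tokens)

    stats = {}
    for token in set(proposed) | set(appearances):
        # precision = #accepted(i) / #proposed(i)
        precision = accepted[token] / proposed[token] if proposed[token] else 0.0
        # recall = #accepted(i) / #appearances of i in the final answers
        recall = accepted[token] / appearances[token] if appearances[token] else 0.0
        stats[token] = (precision, recall)
    return stats
```

A token the draft model never proposes gets zero precision and zero recall under this convention.
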
├─73 html "<div class=\"center\">" (691:1-691:21, 38044-38064)
├─74 html "<div class=\"footnotesize\">" (693:1-693:27, 38066-38092)
├─75 html "<div id=\"tab:tokens\">" (695:1-695:22, 38094-38115)
├─76 paragraph[15] (697:1-700:402, 38117-39724)
│ ├─0 text "| " (697:1-697:3, 38117-38119)
│ ├─1 strong[1] (697:3-697:14, 38119-38130)
│ │ └─0 text "Dataset" (697:5-697:12, 38121-38128)
│ ├─2 text " | " (697:14-697:53, 38130-38169)
│ ├─3 strong[1] (697:53-697:63, 38169-38179)
│ │ └─0 text "Spider" (697:55-697:61, 38171-38177)
│ ├─4 text " | " (697:63-697:138, 38179-38254)
│ ├─5 strong[1] (697:138-697:147, 38254-38263)
│ │ └─0 text "Gsm8k" (697:140-697:145, 38256-38261)
│ ├─6 text " | " (697:147-697:218, 38263-38334)
│ ├─7 strong[1] (697:218-697:236, 38334-38352)
│ │ └─0 text "Alpaca-Finance" (697:220-697:234, 38336-38350)
│ ├─8 text " | " (697:236-697:316, 38352-38432)
│ ├─9 strong[1] (697:316-697:331, 38432-38447)
│ │ └─0 text "Code-Python" (697:318-697:329, 38434-38445)
│ ├─10 text " |\n|:------------------------------------------------|:-----------------------------------------------------------------------------------|:------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------|\n| " (697:331-699:3, 38447-38923)
│ ├─11 strong[1] (699:3-699:50, 38923-38970)
│ │ └─0 text "Tokens with the greatest precision increase" (699:5-699:48, 38925-38968)
│ ├─12 text " | AV, SELECT, first, ⟨EOS⟩, template, SUM, G, COUNT, \\n, city, WHERE, ’;, (, IST, id | ⟨EOS⟩, >>, +, To, <<, this, =, %, know, are, We, calculate, be, The, have | 1, Here, (, :, provide, depends, However, goals, amount, 3, there, The, \\n, personal, will | ”’, (, Here, python, ’, how, doc, snippet, import, based, {, Python, This, :, you |\n| " (699:50-700:3, 38970-39325)
│ ├─13 strong[1] (700:3-700:47, 39325-39369)
│ │ └─0 text "Tokens with the greatest recall increase" (700:5-700:45, 39327-39367)
│ └─14 text " | SELECT, *, FROM, (, IST, *), \\n, COUNT, G, first, WHERE, ⟨EOS⟩, IN, ;, MAX, ’; | start, >>, <<, +, find, how, we, =, fore, To, so, \\ ⟨EOS⟩, then, let | general, 1, several, This, depends, Here, provide, However, goals, over, (, If, amount, it, can | Here, This, snippet, ”’, ’, how, python, (, takes, Python, you, doc, an, import, def |" (700:47-700:402, 39369-39724)
├─77 paragraph[1] (702:1-704:17, 39726-39874)
│ └─0 text "Top 15 tokens with the most recall/precision improvement across\ndatasets. We ignore _ before tokens, which represents space in the\nLLaMA tokenizer." (702:1-704:17, 39726-39874)
├─78 html "</div>" (706:1-706:7, 39876-39882)
├─79 html "</div>" (708:1-708:7, 39884-39890)
├─80 html "</div>" (710:1-710:7, 39892-39898)
├─81 paragraph[6] (712:1-723:58, 39900-40712)
│ ├─0 strong[1] (712:1-712:45, 39900-39944)
│ │ └─0 text "Tokens learned across different datasets" (712:3-712:43, 39902-39942)
│ ├─1 text " We analyze\nthe top 15 tokens with the largest precision and recall\nimprovements on each dataset, restricting attention to the 100 most frequent\ntokens, to understand what the draft model learns. As detailed in\nTable " (712:45-716:7, 39944-40182)
│ ├─2 html "<a href=\"#tab:tokens\" data-reference-type=\"ref\"\ndata-reference=\"tab:tokens\">" (716:7-717:29, 40182-40258)
│ ├─3 text "2" (717:29-717:30, 40258-40259)
│ ├─4 html "</a>" (717:30-717:34, 40259-40263)
│ └─5 text ", the improved tokens align well with\nthe underlying data distribution. For example, in the Spider dataset,\nwhose answers are mostly SQL statements, tokens like SELECT and WHERE\nhave notably higher acceptance rates after distillation. Similarly, in\nthe Grade School Math dataset (Gsm8k), tokens such as <<, >>, =, and +\nstand out. These patterns highlight the draft model's ability to adapt\nand predict tokens consistent with the data distribution." (717:34-723:58, 40263-40712)
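
The selection behind the token table can be sketched as a ranking by metric gain. This example assumes hypothetical `before` and `after` dictionaries that map each token to a metric, such as per-token precision before and after distillation, and a list of the most frequent tokens. It is an illustration, not the authors' script.

```python
def top_improved_tokens(before, after, frequent_tokens, n=15):
    """Return the n tokens among frequent_tokens with the largest increase
    (after - before) in a per-token metric; missing entries count as 0."""
    gains = {t: after.get(t, 0.0) - before.get(t, 0.0) for t in frequent_tokens}
    # sorted() is stable, so ties keep the order of frequent_tokens
    return sorted(gains, key=lambda t: gains[t], reverse=True)[:n]
```
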
├─82 heading[1] (725:1-725:13, 40714-40726)
│ │ depth: 1
│ └─0 text "Conclusion" (725:3-725:13, 40716-40726)
├─83 paragraph[1] (727:1-732:57, 40728-41120)
│ └─0 text "The efficiency of speculative decoding hinges on how closely the draft model\napproximates the target model. We introduce online speculative decoding, a\nmethod that continuously improves the draft model as the data\ndistribution changes. Experiments on both synthetic and real data demonstrate\nthat online speculative decoding swiftly adapts to new data\ndistributions, significantly increasing the token acceptance rate." (727:1-732:57, 40728-41120)
├─84 heading[1] (734:1-734:11, 41122-41132)
│ │ depth: 1
│ └─0 text "Appendix" (734:3-734:11, 41124-41132)
├─85 heading[1] (736:1-736:35, 41134-41168)
│ │ depth: 2
│ └─0 text "Speedup of Speculative Decoding" (736:4-736:35, 41137-41168)
├─86 paragraph[1] (738:1-749:56, 41170-41949)
│ └─0 text "As proved in  , compared with standard decoding, the expected\nimprovement factor for offline speculative decoding is\n$\\frac{1-\\alpha^{k+1}}{(1-\\alpha)(ck+1)}$. Let the time taken for a\nsingle run of $M_p$ be $T$. Define $c$, the cost coefficient, as the\nratio of the time taken for a single run of $M_q$ to that of $M_p$. Each\nexecution of lines 7 to 8 takes $Tck + T$ and, on average, yields\n$\\frac{1-\\alpha^{k+1}}{1-\\alpha}$ tokens. As a result, the average time\nto produce one token using speculative decoding is\n$\\frac{(ck+1)(1-\\alpha)}{1-\\alpha^{k+1}}T$. In contrast, the time to\nproduce a single token using standard decoding is $T$. Hence, the\nwall-clock speedup of offline speculative decoding is\n$\\frac{1-\\alpha^{k+1}}{(1-\\alpha)(ck+1)}$." (738:1-749:56, 41170-41949)
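
The improvement factor can be evaluated directly. This is a minimal sketch: the function names are illustrative, and `alpha`, `k` and `c` follow the definitions above (acceptance rate, number of proposed tokens, and cost coefficient of $M_q$ relative to $M_p$).

```python
def expected_tokens_per_run(alpha: float, k: int) -> float:
    """Expected tokens generated per target-model run: (1 - alpha^(k+1)) / (1 - alpha)."""
    if alpha == 1.0:
        return k + 1.0  # limit as alpha -> 1: every proposal is accepted
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def speculative_speedup(alpha: float, k: int, c: float) -> float:
    """Wall-clock improvement factor over standard decoding.

    One iteration costs T*c*k (k draft runs) + T (one target run) and yields
    expected_tokens_per_run(alpha, k) tokens, while standard decoding spends
    T per token.
    """
    return expected_tokens_per_run(alpha, k) / (c * k + 1)
```

For example, with α = 0.8, k = 4 and a draft model costing 5% of the target (c = 0.05), the expected speedup is about 2.8×.
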
├─87 heading[1] (751:1-751:20, 41951-41970)
│ │ depth: 2
│ └─0 text "Latency Analysis" (751:4-751:20, 41954-41970)
├─88 paragraph[5] (753:1-762:43, 41972-42674)
│ ├─0 text "Suppose OSD can improve the token acceptance rate from $\\alpha_1$ to\n$\\alpha_2$ and $T$ is the generation time for standard decoding. Based\non Equation " (753:1-755:13, 41972-42124)
│ ├─1 html "<a href=\"#eq:gen_len\" data-reference-type=\"ref\"\ndata-reference=\"eq:gen_len\">" (755:13-756:29, 42124-42200)
│ ├─2 text "[eq:gen_len]" (756:29-756:41, 42200-42212)
│ ├─3 html "</a>" (756:41-756:45, 42212-42216)
│ └─4 text ", this improvement leads to\na decrease in the average generation time for each token, transitioning\nfrom $\\frac{(ck+1)(1-\\alpha_1)}{1-\\alpha_{1}^{k+1}}T$ to\n$\\frac{(ck+1)(1-\\alpha_2)}{1-\\alpha_{2}^{k+1}}T$. Consequently, this\nresults in a speedup factor of\n$\\frac{1-\\alpha_2^{k+1}}{1-\\alpha_1^{k+1}}\\frac{1-\\alpha_1}{1-\\alpha_2} = \\frac{1+\\alpha_2+\\alpha_2^2+...+\\alpha_2^{k}}{1+\\alpha_1+\\alpha_1^2+...+\\alpha_1^k}$\ncompared to standard speculative decoding." (756:45-762:43, 42216-42674)
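
Because $k$ and $c$ are unchanged, the $(ck+1)$ factor cancels and the speedup reduces to a ratio of truncated geometric sums. A minimal sketch follows; `osd_speedup` is an illustrative name.

```python
def osd_speedup(alpha1: float, alpha2: float, k: int) -> float:
    """Speedup of OSD over offline speculative decoding when the acceptance
    rate rises from alpha1 to alpha2 (same k and c):
    (1 + a2 + ... + a2^k) / (1 + a1 + ... + a1^k)."""
    numerator = sum(alpha2 ** i for i in range(k + 1))
    denominator = sum(alpha1 ** i for i in range(k + 1))
    return numerator / denominator
```

For instance, raising α from 0.5 to 0.8 with k = 4 gives roughly a 1.7× speedup over offline speculative decoding.
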
├─89 paragraph[1] (764:1-774:69, 42676-43440)
│ └─0 text "In the analysis above, we omitted the additional latency of\nupdating the draft model for the following reasons: (1) As shown in the\nFLOPs analysis below, the additional computational cost of the update\nis marginal compared with the computational demands of\nrunning the target model. (2) Updates are periodic; under\nmoderate request load, the latency of serving individual requests\nremains largely unaffected. Additionally, because the update\noperation for the draft model is considerably less resource-intensive\nthan inference, its latency can likely be hidden,\nmaking it virtually imperceptible. Lastly, updating\nand inference can even be executed concurrently on separate devices." (764:1-774:69, 42676-43440)
├─90 heading[1] (776:1-776:18, 43442-43459)
│ │ depth: 2
│ └─0 text "Flops Analysis" (776:4-776:18, 43445-43459)
├─91 paragraph[6] (778:1-803:79, 43461-45107)
│ ├─0 emphasis[1] (778:1-779:51, 43461-43581)
│ │ └─0 text "The FLOPs required to update the draft model are significantly fewer\nthan those needed for inference on a large model." (778:2-779:50, 43462-43580)
│ ├─1 text " Denote $L$ as the\naverage length of the generated sequence. For each verification, the\ndraft model suggests $k$ tokens. The expected length for a single run of\nthe target LLM, denoted as $a$, can be calculated using\nEquation " (779:51-783:10, 43581-43807)
│ ├─2 html "<a href=\"#eq:gen_len\" data-reference-type=\"ref\"\ndata-reference=\"eq:gen_len\">" (783:10-784:29, 43807-43883)
│ ├─3 text "[eq:gen_len]" (784:29-784:41, 43883-43895)
│ ├─4 html "</a>" (784:41-784:45, 43895-43899)
│ └─5 text ". Therefore, OSD runs\nthe verification process $\\frac{L}{a}$ times, each time verifying\n$k+1$ tokens. We use $F_{qfwd}$ to denote the FLOPs\nrequired by a single forward run of the draft model per token,\nand $F_{pfwd}$ to denote the FLOPs needed for a single forward run of\nthe target model per token. The computational demand (in\nFLOPs) of the draft and target models to handle one request can then be\nexpressed as:\n$\\text{FLOPs}(draft) = \\frac{L}{a} \\times k \\times F_{qfwd},\n\\text{FLOPs}(target) = \\frac{L}{a} \\times (k+1) \\times F_{pfwd}.$ Let\n$F_{qbwd}$ denote the FLOPs required to update the draft model per token.\nThe cumulative FLOPs necessary to process $I$ requests are\ngiven by:\n$$\\frac{LI}{a} \\times \\left[k \\times F_{qfwd} + (k+1) \\times F_{pfwd}\\right] + I \\times L \\times F_{qbwd}.$$\nBased on the findings of , training is approximately three times\nas costly as inference: roughly 6 FLOPs per\nparameter to train on a single token versus 2 FLOPs per parameter to\nrun inference on one token. Substituting $F_{qbwd} \\approx 3F_{qfwd}$, we can simplify the total FLOPs expression\nto:\n$$\\frac{LI}{a}\\left[(k + 3a) \\times F_{qfwd} + (k+1) \\times F_{pfwd}\\right].$$" (784:45-803:79, 43899-45107)
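
The simplification can be checked numerically by computing the total both ways. In this sketch the inputs are illustrative assumptions: per-token forward FLOPs are taken as 2 × parameters, and $F_{qbwd}$ is set to $3F_{qfwd}$.

```python
def total_flops(L, I, a, k, f_qfwd, f_pfwd, f_qbwd):
    """Total FLOPs for I requests: draft proposals + target verification + draft updates."""
    verifications = L * I / a                   # target runs across all requests
    draft = verifications * k * f_qfwd          # k draft forward passes per run
    target = verifications * (k + 1) * f_pfwd   # target scores k + 1 positions per run
    update = I * L * f_qbwd                     # one draft update per generated token
    return draft + target + update


def total_flops_simplified(L, I, a, k, f_qfwd, f_pfwd):
    """Same quantity after substituting f_qbwd = 3 * f_qfwd."""
    return L * I / a * ((k + 3 * a) * f_qfwd + (k + 1) * f_pfwd)


# Illustrative values: LLaMA-160M draft, Vicuna-7B target, 2 FLOPs per parameter
f_q, f_p = 2 * 160e6, 2 * 7e9
full = total_flops(200, 10, 3.0, 5, f_q, f_p, 3 * f_q)
short = total_flops_simplified(200, 10, 3.0, 5, f_q, f_p)
```

Under the stated assumption, `full` and `short` agree up to floating-point rounding.
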
├─92 paragraph[1] (805:1-814:62, 45109-45660)
│ └─0 text "The ratio of the FLOPs needed to run the target model to those needed to\nrun and update the draft model is given by:\n$$\\frac{(k+1)\\times F_{pfwd}}{(k+3a)\\times F_{qfwd}}.$$ For the two\nmodel pairs evaluated, assuming $k = 5$ proposed tokens per run:\n(1) for (LLaMA-160M, Vicuna-7B), with an average acceptance rate of 0.71, the\nratio is approximately\n$\\frac{(5+1) \\times 7B}{(5+3 \\times 3) \\times 160M} = 18.75$. (2)\nFor (T5-small 80M, Flan-T5-XL 3B), with an average acceptance rate of 0.76,\nthe ratio is roughly\n$\\frac{(5+1) \\times 3B}{(5+3 \\times 4.3) \\times 80M} \\approx 12.6$." (805:1-814:62, 45109-45660)
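
The two ratios can be reproduced with a short helper. This sketch assumes that per-token forward FLOPs scale as 2 × parameters, so the factor of 2 cancels. The values of $a$ (3 and 4.3) are taken from the substitutions in the text.

```python
def target_to_draft_flops_ratio(k, a, target_params, draft_params):
    """(k+1) * F_pfwd / ((k + 3a) * F_qfwd), with per-token forward FLOPs
    approximated as 2 FLOPs per parameter."""
    f_pfwd = 2 * target_params
    f_qfwd = 2 * draft_params
    return (k + 1) * f_pfwd / ((k + 3 * a) * f_qfwd)


print(round(target_to_draft_flops_ratio(5, 3, 7e9, 160e6), 2))   # 18.75
print(round(target_to_draft_flops_ratio(5, 4.3, 3e9, 80e6), 2))  # 12.57
```
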
├─93 paragraph[2] (816:1-830:28, 45662-46665)
│ ├─0 emphasis[1] (816:1-817:45, 45662-45766)
│ │ └─0 text "In practical systems, the FLOPs required for inference are\nsignificantly below the machine’s capacity." (816:2-817:44, 45663-45765)
│ └─1 text " Consider\nLMSYS-Chat-1M , which comprises traces spanning 125 days with 1,000,000\nrequests, averaging fewer than 2,000 tokens per request (including both\nprompts and responses). When serving a 30B model with 8 A100 GPUs, the\nFLOPs consumed per second can be estimated as follows (again assuming 2\nFLOPs per parameter per token):\n$$\\frac{2000 \\times 1{,}000{,}000}{125 \\times 24 \\times 3600} \\times 30 \\times 10^9 \\times 2 \\approx 1.1 \\times 10^{13} \\text{ FLOPs/s (about 11 TFLOPs/s)}$$\nBy contrast, 8 A100 GPUs offer a combined peak capacity of\n$8 \\times 312 \\text{ TFLOPs/s} \\approx 2{,}500 \\text{ TFLOPs/s}$, so compute utilization is\nbelow 0.5%. While Arena (the platform that generated LMSYS-Chat-1M) may\nnot be the most efficient deployment and may lack substantial traffic, it is the\nonly publicly accessible LLM service trace. Even if the load were amplified\nseveral times, the calculation above indicates that compute\nutilization would remain low." (817:45-830:28, 45766-46665)
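
The utilization estimate can be recomputed from the trace statistics given in the text: 2,000 tokens per request, 1,000,000 requests, 125 days, a 30B-parameter model, 2 FLOPs per parameter per token, and 8 A100s at 312 TFLOP/s peak each. The helper name is illustrative.

```python
def serving_flops_per_second(tokens_per_request, requests, days, params,
                             flops_per_param_token=2):
    """Average FLOP/s implied by a request trace (inference only)."""
    seconds = days * 24 * 3600
    tokens_per_second = tokens_per_request * requests / seconds
    return tokens_per_second * params * flops_per_param_token


demand = serving_flops_per_second(2000, 1_000_000, 125, 30e9)
capacity = 8 * 312e12  # peak FLOP/s of 8 A100 GPUs, as assumed in the text
print(f"{demand:.2e} FLOP/s, utilization {demand / capacity:.2%}")
```

This prints a demand of roughly 1.1 × 10^13 FLOP/s and a utilization below 0.5%.
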
├─94 heading[1] (832:1-832:12, 46667-46678)
│ │ depth: 2
│ └─0 text "Data Mix" (832:4-832:12, 46670-46678)
├─95 paragraph[5] (834:1-852:12, 46680-47896)
│ ├─0 text "Moreover, there is a question of whether the draft model, once adapted\nto the new distribution, might lose its prior knowledge. To probe this,\nwe conducted an experiment mixing 2k prompts each from the Gsm8k and\nAlpaca-finance datasets. During online serving, for the initial 2k\nrequests, we only update the model based on data from the Gsm8k dataset.\nFor the subsequent half of the requests, we restrict updates solely to\ndata from the Alpaca-finance dataset. We then provide the average token\nacceptance rates for all requests, segmented by their data source (Gsm8k\nversus Alpaca-finance). As depicted in\nFigure " (834:1-843:8, 46680-47294)
│ ├─1 html "<a href=\"#fig:mix\" data-reference-type=\"ref\"\ndata-reference=\"fig:mix\">" (843:8-844:26, 47294-47364)
│ ├─2 text "7" (844:26-844:27, 47364-47365)
│ ├─3 html "</a>" (844:27-844:31, 47365-47369)
│ └─4 text ", the token acceptance rate for Gsm8k\nincreases as the draft model is exposed to more data. Conversely, the\nacceptance rate ($\\alpha$) for the Alpaca-finance dataset remains\nflat. This is expected, since during this phase we only update the draft model\nusing Gsm8k data. In the latter half of the stream, the token\nacceptance rate for the Alpaca-finance dataset also trends upward.\nNotably, the rate for Gsm8k remains stable, suggesting that the\ndraft model retains its learned knowledge without signs of\nforgetting." (844:31-852:12, 47369-47896)
├─96 html "<figure id=\"fig:mix\">\n<p><embed src=\"figures/appendix_legend.pdf\" /><br />\n<embed src=\"figures/mix.pdf\" /></p>\n<figcaption>Mix of distributions.</figcaption>\n</figure>" (854:1-858:10, 47898-48065)
├─97 heading[1] (860:1-860:51, 48067-48117)
│ │ depth: 2
│ └─0 text "Data Preparation for Distribution Shift Analysis" (860:4-860:51, 48070-48117)
├─98 paragraph[4] (862:1-867:61, 48119-48459)
│ ├─0 text "To emulate this shift in distribution,\n" (862:1-863:1, 48119-48158)
│ ├─1 html "<span id=\"appendix:distribution-shift\"\nlabel=\"appendix:distribution-shift\">" (863:1-864:37, 48158-48233)
│ ├─2 html "</span>" (864:37-864:44, 48233-48240)
│ └─3 text " we select 2k prompts from\neach dataset under evaluation. The data from the four datasets are\ncombined by direct concatenation, such that records\n$i\\times2k$ to $(i+1)\\times2k$ belong solely to dataset $i$." (864:44-867:61, 48240-48459)
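
The concatenated stream can be built as follows. This is a minimal sketch that assumes the first `per_dataset` records of each dataset are used. The text does not say how the 2k prompts are selected, so that choice, like the function name, is hypothetical.

```python
def concatenate_for_shift(datasets, per_dataset=2000):
    """Build a distribution-shift stream in which records
    [i * per_dataset, (i + 1) * per_dataset) come only from dataset i.

    Assumption: the first per_dataset records of each dataset are used.
    """
    stream = []
    for records in datasets:
        if len(records) < per_dataset:
            raise ValueError("each dataset needs at least per_dataset records")
        stream.extend(records[:per_dataset])
    return stream
```

Serving requests in stream order then exposes the draft model to one distribution at a time.
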
├─99 heading[1] (869:1-869:17, 48461-48477)
│ │ depth: 2
│ └─0 text "Arena Dataset" (869:4-869:17, 48464-48477)
├─100 paragraph[1] (871:1-876:15, 48479-48841)
│ └─0 text "For expedited experimental evaluation, we randomly sample a subset with\n10K records from LMSYS-Chat-1M , a comprehensive real-world LLM\nconversation dataset. This dataset encompasses interactions with 25\nmodels spanning from April to August 2023 and features conversations in\nover 150 languages. For all experiments, we only pick conversations for\nVicuna models." (871:1-876:15, 48479-48841)
└─101 paragraph[1] (878:1-878:58, 48843-48900)
  └─0 text "[^1]: ${\\bm{y}}_{<i}$ refers to $\\{y_j\\}_{j=1}^{i-1}$." (878:1-878:58, 48843-48900)

Introduction

With the increased interest in deep learning in recent years, there has been an explosion of machine learning tools. Many popular frameworks such as Caffe (Jia et al. 2014), CNTK (Seide and Agarwal 2016), TensorFlow (Abadi et al. 2015), and Theano (Theano Development Team 2016), construct a static dataflow graph that represents the computation and which can then be applied repeatedly to batches of data. This approach provides visibility into the whole computation ahead of time, and can theoretically be leveraged to improve performance and scalability. However, it comes at the cost of ease of use, ease of debugging, and flexibility of the types of computation that can be represented.

Prior work has recognized the value of dynamic eager execution for deep learning, and some recent frameworks implement this define-by-run approach, but do so either at the cost of performance (Chainer (Tokui et al. 2015)) or using a less expressive, faster language (Torch (Collobert, Bengio, and Mariéthoz 2002), DyNet (Neubig et al. 2017)), which limits their applicability.

However, with careful implementation and design choices, dynamic eager execution can be achieved largely without sacrificing performance. This paper introduces PyTorch, a Python library that performs immediate execution of dynamic tensor computations with automatic differentiation and GPU acceleration, and does so while maintaining performance comparable to the fastest current libraries for deep learning. This combination has turned out to be very popular in the research community with, for instance, 296 ICLR 2019 submissions mentioning PyTorch.

Background

Four major trends in scientific computing have become increasingly important for deep learning.

First, starting in the 1960s, the development of domain-specific languages such as APL (Abrams 1970), MATLAB (MATLAB and Statistics Toolbox, n.d.), R (R Core Team, n.d.) and Julia (Bezanson et al. 2017) turned multidimensional arrays (often referred to as tensors) into first-class objects supported by a comprehensive set of mathematical primitives (or operators) to manipulate them. Separately, libraries such as NumPy (Oliphant 2006), Torch (Collobert, Bengio, and Mariéthoz 2002), Eigen (Guennebaud, Jacob, et al. 2010) and Lush (Y. LeCun and Bottou 2002) made array-based programming productive in general-purpose languages such as Python, Lisp, C++ and Lua.

Second, the development of automatic differentiation (Baydin et al. 2017) made it possible to fully automate the daunting labor of computing derivatives. This made it significantly easier to experiment with different machine learning approaches while still allowing for efficient gradient-based optimization. The autograd (Maclaurin 2016) package popularized the use of this technique for NumPy arrays, and similar approaches are used in frameworks such as Chainer (Tokui et al. 2015), DyNet (Neubig et al. 2017), Lush (Y. LeCun and Bottou 2002), Torch (Collobert, Bengio, and Mariéthoz 2002), Jax (M. J. et al. 2018) and Flux.jl (M. I. et al. 2018).

Third, with the advent of the free software movement, the scientific community moved away from closed proprietary software such as Matlab (MATLAB and Statistics Toolbox, n.d.), and towards the open-source Python ecosystem with packages like NumPy (Oliphant 2006), SciPy (Jones et al. 2001--), and Pandas (McKinney 2010). This fulfilled most of the numerical analysis needs of researchers while allowing them to take advantage of a vast repository of libraries to handle dataset preprocessing, statistical analysis, plotting, and more. Moreover, the openness, interoperability, and flexibility of free software fostered the development of vibrant communities that could quickly address new or changing needs by extending the existing functionality of a library or, if needed, by developing and releasing brand new ones. While there is a rich offering of open-source software for neural networks in languages other than Python, starting with Lush (Y. LeCun and Bottou 2002) in Lisp, Torch (Collobert, Bengio, and Mariéthoz 2002) in C++, Objective-C and Lua, EBLearn (Sermanet, Kavukcuoglu, and LeCun 2009) in C++, and Caffe (Jia et al. 2014) in C++, the network effects of a large ecosystem such as Python made it an essential skill to jumpstart one's research. Hence, since 2014, most deep learning frameworks have converged on a Python interface as an essential feature.

Finally, the availability and commoditization of general-purpose massively parallel hardware such as GPUs provided the computing power required by deep learning methods. Specialized libraries such as cuDNN (Chetlur et al. 2014), along with a body of academic work (such as (Lavin 2015) and (Lavin and Gray 2016)), produced a set of high-performance reusable deep learning kernels that enabled frameworks such as Caffe (Jia et al. 2014), Torch7 (Collobert, Kavukcuoglu, and Farabet 2011), or TensorFlow (Abadi et al. 2015) to take advantage of these hardware accelerators.

PyTorch builds on these trends by providing an array-based programming model accelerated by GPUs and differentiable via automatic differentiation integrated in the Python ecosystem.

Design principles

PyTorch's success stems from weaving previous ideas into a design that balances speed and ease of use. There are four main principles behind our choices:

Be Pythonic: Data scientists are familiar with the Python language, its programming model, and its tools. PyTorch should be a first-class member of that ecosystem. It follows the commonly established design goals of keeping interfaces simple and consistent, ideally with one idiomatic way of doing things. It also integrates naturally with standard plotting, debugging, and data processing tools.

Put researchers first: PyTorch strives to make writing models, data loaders, and optimizers as easy and productive as possible. The complexity inherent to machine learning should be handled internally by the PyTorch library and hidden behind intuitive APIs free of side effects and unexpected performance cliffs.

Provide pragmatic performance: To be useful, PyTorch needs to deliver compelling performance, although not at the expense of simplicity and ease of use. Trading 10% of speed for a significantly simpler-to-use model is acceptable; 100% is not. Therefore, its implementation accepts added complexity in order to deliver that performance. Additionally, providing tools that allow researchers to manually control the execution of their code will empower them to find their own performance improvements independent of those that the library provides automatically.

Worse is better (Gabriel, n.d.): Given a fixed amount of engineering resources, and all else being equal, the time saved by keeping the internal implementation of PyTorch simple can be used to implement additional features, adapt to new situations, and keep up with the fast pace of progress in the field of AI. Therefore, it is better to have a simple but slightly incomplete solution than a comprehensive but complex and hard-to-maintain design.

Usability centric design

Deep learning models are just Python programs

In a surprisingly short amount of time, machine learning grew from recognizing individual digits (Yann LeCun and Cortes, n.d.) into autonomously playing StarCraft (Vinyals et al. 2017). Consequently, the neural networks themselves evolved rapidly from simple sequences of feed-forward layers into incredibly varied numerical programs often composed of many loops and recursive functions. To support this growing complexity, PyTorch forgoes the potential benefits of a graph-metaprogramming based approach to preserve the imperative programming model of Python. This design was pioneered for model authoring by Chainer (Tokui et al. 2015) and DyNet (Neubig et al. 2017). PyTorch extends this to all aspects of deep learning workflows. Defining layers, composing models, loading data, running optimizers, and parallelizing the training process are all expressed using the familiar concepts developed for general purpose programming.

This solution ensures that any new potential neural network architecture can be easily implemented with PyTorch. For instance, layers (which in modern machine learning should really be understood as stateful functions with implicit parameters) are typically expressed as Python classes whose constructors create and initialize their parameters, and whose forward methods process an input activation. Similarly, models are usually represented as classes that compose individual layers, but we stress again that nothing forces the user to structure their code in that way. Listing [lst:code_example] demonstrates how an entire model can be created by composing functionality provided by PyTorch such as 2d convolution, matrix multiplication, dropout, and softmax to classify gray-scale images. Note that linear layers are of course part of the library, but we show an example implementation to highlight how simple it is.

This "everything is just a program" philosophy is not limited to models; it applies to optimizers and data loaders as well. This makes it easy to experiment with new training techniques. For example, to implement the popular generative adversarial networks, one needs to specify two separate models (the generator and the discriminator) and two loss functions that depend on both models at the same time. Rigid APIs would struggle with this setup, but the simple design employed in PyTorch easily adapts to it, as shown in Listing 1.

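from torch import optim

# create_discriminator, create_generator, get_noise, loss, real_label and
# fake_label are placeholders assumed to be defined elsewhere.
discriminator = create_discriminator()
generator = create_generator()
optimD = optim.Adam(discriminator.parameters())
optimG = optim.Adam(generator.parameters())

def step(real_sample):
    # (1) Update Discriminator: real samples -> real_label, fakes -> fake_label
    optimD.zero_grad()
    errD_real = loss(discriminator(real_sample), real_label)
    errD_real.backward()
    fake = generator(get_noise())
    # detach() keeps the discriminator loss from backpropagating into G
    errD_fake = loss(discriminator(fake.detach()), fake_label)
    errD_fake.backward()
    optimD.step()
    # (2) Update Generator: push D to classify fake samples as real
    optimG.zero_grad()
    errG = loss(discriminator(fake), real_label)
    errG.backward()
    optimG.step()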

Since PyTorch programs execute eagerly, all the features of Python are available throughout the whole design process. Print statements, standard debuggers, and common visualization tools like matplotlib all work as expected. Users do not have to wait for lengthy compilation before they can start running their programs, and more importantly intermediate computations can be observed to understand how a model works and whether its results are correct.

Interoperability and extensibility

Easy and efficient interoperability is one of the top priorities for PyTorch because it opens the possibility to leverage the rich ecosystem of Python libraries as part of user programs. Hence, PyTorch allows for bidirectional exchange of data with external libraries. For example, it provides a mechanism to convert between NumPy arrays and PyTorch tensors using the torch.from_numpy() function and .numpy() tensor method. Similar functionality is also available to exchange data stored using the DLPack (DMLC, n.d.) format. Note that in both cases this exchange happens without any data copying: objects on both sides only describe how to interpret a memory region that is shared between them. Hence, these operations are extremely cheap and take constant time no matter how large the converted arrays are.

Moreover, many of the critical systems are designed specifically to be extensible. For instance, the automatic differentiation system allows users to add support for custom differentiable functions. To do that, users can define a new subclass of torch.autograd.Function that implements forward() and backward() methods, which specify the function and its derivative (or more formally the vector-Jacobian product). Similarly, new datasets can be added by subclassing torch.utils.data.Dataset and implementing two methods: __getitem__ (the indexing operator) and __len__ (the length operator), making datasets behave like (possibly lazy) lists. How these work is completely up to the implementer, and many users leverage other Python packages for data loading. The DataLoader class consumes objects conforming to this interface and provides an iterator over the data which takes care of shuffling, batching, parallelization, and management of pinned CUDA memory to improve throughput.
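
The two-method dataset protocol can be illustrated without PyTorch installed. In the sketch below, `SquaresDataset` and `iterate_batches` are hypothetical names. In PyTorch the class would subclass torch.utils.data.Dataset and be consumed by DataLoader, which also handles parallel workers and pinned memory; this toy stand-in only shuffles and batches.

```python
import random


class SquaresDataset:
    """Map-style dataset: __getitem__ and __len__ make it behave like a lazy list."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index * index  # computed on demand; nothing is stored

    def __len__(self):
        return self.n


def iterate_batches(dataset, batch_size, shuffle=False, seed=0):
    """Toy stand-in for a data loader: optionally shuffles indices, then yields batches."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]
```

Because samples are produced inside `__getitem__`, the same interface works whether the data lives in memory, on disk, or is generated on the fly.
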

Most importantly, users are free to replace any component of PyTorch that does not meet the needs or performance requirements of their project. They are all designed to be completely interchangeable, and PyTorch takes great care not to impose any particular solution.

Automatic differentiation

Since gradient-based optimization is vital to deep learning, PyTorch must be able to automatically compute gradients of models specified by our users, and those can be arbitrary Python programs. However, Python is a dynamic programming language that allows changing most behaviors at runtime, making ahead-of-time source-to-source differentiation cumbersome. Instead, PyTorch uses the operator overloading approach, which builds up a representation of the computed function every time it is executed. In its current implementation (Paszke et al. 2017), PyTorch performs reverse-mode automatic differentiation, which computes the gradient of a scalar output with respect to a multivariate input. Differentiating functions with more outputs than inputs is more efficiently executed using forward-mode automatic differentiation, but this use case is less common for machine learning applications. PyTorch can be easily extended to perform forward-mode differentiation using array-level dual numbers (Piponi 2004; Leuck and Nagel 1999).
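
The operator-overloading approach can be shown at toy scale on scalars. Each arithmetic operation on the hypothetical `Var` class below records its inputs and local derivatives while the program runs. `backward()` then replays that record in reverse topological order. This is an illustrative sketch of reverse-mode AD, not PyTorch's implementation, which works on tensors.

```python
import math


class Var:
    """Scalar that records the operations applied to it as the program runs."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (parent Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    __rmul__ = __mul__

    def sin(self):
        return Var(math.sin(self.value), ((self, math.cos(self.value)),))

    def backward(self):
        """Reverse-mode sweep: propagate gradients from the output to the inputs."""
        order, seen = [], set()

        def visit(node):  # depth-first topological sort of the recorded graph
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += local * node.grad  # chain rule, accumulated
```

For `z = x * y + x.sin()`, calling `z.backward()` gives `x.grad == y.value + cos(x.value)` and `y.grad == x.value`. Because the graph is rebuilt on every run, ordinary Python `if` statements and loops can change its shape from one call to the next.
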

Another interesting and uncommon feature of our system is that it can differentiate through code employing mutation on tensors, which is one of the basic building blocks of imperative programs. To ensure safety, we have implemented a versioning system for tensors, which lets us track their modifications and ensure that we always use the data we expect. One interesting tradeoff is that while we could utilize techniques like copy-on-write to support arbitrary programs, we chose to not go down this path, as performance-wise it is usually beneficial for the users to rewrite their code to ensure that no copies have to be performed. Hence, while most mutations are benign and can be handled automatically, the really complicated cases result in a user error, which lets them know that they likely want to restructure the program. This allows us to avoid introducing subtle and hard-to-find performance cliffs.
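
The versioning idea can be modeled in a few lines. In the sketch below, the hypothetical `VersionedBuffer` bumps a counter on every in-place write. `SavedForBackward` records that counter when a value is saved for the backward pass and refuses to hand the value back if it has changed since. PyTorch keeps its counters inside tensors and reports errors with its own messages, so this only illustrates the mechanism.

```python
class VersionedBuffer:
    """List-like buffer whose version counter is bumped on every in-place write."""

    def __init__(self, values):
        self.values = list(values)
        self.version = 0

    def __getitem__(self, i):
        return self.values[i]

    def __setitem__(self, i, v):  # in-place mutation
        self.values[i] = v
        self.version += 1


class SavedForBackward:
    """Remembers a buffer and its version at save time; checks it when unpacked."""

    def __init__(self, buffer):
        self.buffer = buffer
        self.saved_version = buffer.version

    def unpack(self):
        if self.buffer.version != self.saved_version:
            raise RuntimeError(
                "a value needed for gradient computation was modified in place "
                f"(saved at version {self.saved_version}, now {self.buffer.version})")
        return self.buffer
```

Mutations before a value is saved are harmless. A mutation after saving surfaces as an explicit error instead of a silently wrong gradient.
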

Performance focused implementation

Running deep learning algorithms efficiently from a Python interpreter is notoriously challenging: for instance, the global interpreter lock (The Python team, n.d.) effectively ensures that only one of any number of concurrent threads is running at any given time. Deep learning frameworks based on the construction of a static data-flow graph sidestep this problem by deferring the evaluation of the computation to a custom interpreter.

PyTorch solved the problem differently, by carefully optimizing every aspect of its execution while simultaneously empowering its users to easily leverage additional optimization strategies.

An efficient C++ core

Despite being closely integrated in the Python ecosystem, most of PyTorch is written in C++ to achieve high performance. This core libtorch library implements the tensor data structure, the GPU and CPU operators, and basic parallel primitives. It also provides the automatic differentiation system, including the gradient formulas for most built-in functions. This ensures that the computation of the derivatives of functions composed of core PyTorch operators is executed entirely in a multithreaded evaluator which does not require holding the Python global interpreter lock (The Python team, n.d.). Python bindings are generated using YAML meta-data files. An interesting side-effect of this approach is that it allowed our community to quickly create bindings to multiple other languages resulting in projects like NimTorch (Petrantoni and Wollenschläger, n.d.), hasktorch (Huang, Hashimoto, and Stites, n.d.) and others.

This design also allowed us to create first-class C++ bindings and modeling libraries that can be used in places where Python is inconvenient, such as the game engine for StarCraft (Synnaeve et al. 2018) or on mobile platforms. It is even possible to take the Python code describing a PyTorch model and run it without Python using the TorchScript engine (The PyTorch team, n.d.b).

Separate control and data flow

PyTorch maintains a strict separation between its control flow (i.e. program branches, loops) and data flow (i.e. tensors and the operations performed on them). The resolution of the control flow is handled by Python and optimized C++ code executed on the host CPU, and results in a linear sequence of operator invocations on the device. Operators can be run either on CPU or on GPU.

PyTorch is designed to execute operators asynchronously on GPU by leveraging the CUDA stream mechanism (Luitjens 2014) to queue CUDA kernel invocations to the GPU's hardware FIFO. This allows the system to overlap the execution of Python code on CPU with tensor operators on GPU. Because the tensor operations usually take a significant amount of time, this lets us saturate the GPU and reach peak performance even in an interpreted language with fairly high overhead like Python. Note that this mechanism is nearly invisible to the user. Unless they implement their own multi-stream primitives, all of the CPU-GPU synchronization is handled by the library.

PyTorch could leverage a similar mechanism to also execute operators asynchronously on the CPU. However, the costs of cross-thread communication and synchronization would negate the performance benefit of such an optimization.
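The effect of queuing work asynchronously on the GPU can be mimicked with a worker thread standing in for the GPU and a FIFO queue standing in for a CUDA stream. This is a minimal sketch under those assumptions; `FakeStream`, `launch` and `synchronize` are hypothetical names, not PyTorch or CUDA APIs. The host loop returns from each `launch` immediately, while the worker executes the "kernels" strictly in submission order, so the host runs ahead until it explicitly synchronizes.

```python
import queue
import threading
import time


class FakeStream:
    """A FIFO of callables executed in order by one worker thread (the 'device')."""

    def __init__(self):
        self._queue = queue.Queue()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def _run(self):
        while True:
            kernel = self._queue.get()
            try:
                kernel()  # kernels run one at a time, in submission order
            finally:
                self._queue.task_done()

    def launch(self, kernel):
        # Asynchronous: enqueue the work and return to the caller immediately.
        self._queue.put(kernel)

    def synchronize(self):
        # Block the host until every queued kernel has finished.
        self._queue.join()


results = []


def make_kernel(i):
    def kernel():
        time.sleep(0.01)  # pretend the device work is slow
        results.append(i)
    return kernel


stream = FakeStream()
start = time.perf_counter()
for i in range(10):
    stream.launch(make_kernel(i))  # host-side control flow runs ahead
enqueue_time = time.perf_counter() - start
stream.synchronize()
total_time = time.perf_counter() - start

print(results)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: FIFO order is preserved
print(f"queued in {enqueue_time:.4f}s, finished after {total_time:.4f}s")
```

Because a single worker drains the queue, ordering is guaranteed without any extra locking, which is the property the caching allocator below relies on when it reasons about streams.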

Custom caching tensor allocator

Almost every operator must dynamically allocate an output tensor to hold the result of its execution. It is therefore critical to optimize the speed of the dynamic memory allocators. PyTorch can rely on optimized libraries (Berger et al. 2000; Evans 2006; Ghemawat and Menage, n.d.) to handle this task on CPU. However, on GPU the cudaFree routine may block its caller until all previously queued work on all GPUs completes. To avoid this bottleneck, PyTorch implements a custom allocator which incrementally builds up a cache of CUDA memory and reassigns it to later allocations without further use of CUDA APIs. The incremental allocation is also crucial for better interoperability, because taking up all GPU memory ahead of time would prevent the user from utilizing other GPU-enabled Python packages.

To further improve its effectiveness, this allocator was tuned for the specific memory usage patterns of deep learning. For example, it rounds up allocations to multiples of 512 bytes to avoid fragmentation issues. Moreover, it maintains a distinct pool of memory for every CUDA stream (work queue).

The one-pool-per-stream design assumption simplifies the implementation and improves the performance of the allocator: because the CPU runs ahead of the GPU, memory is freed on the CPU before its last use on the GPU finishes. Since streams serialize execution, if the free precedes the reallocation on the CPU, the same order will occur on the GPU. So the allocator can reallocate memory freed on the CPU immediately as long as the new allocation is used on the same stream as the freed region. However, if an allocation was last used on one stream and then allocated on another, additional synchronization is needed.
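The caching scheme can be sketched as a toy model in Python. It only tracks sizes and integer handles; `CachingAllocator`, `malloc`, `free` and the `device_malloc` counter are hypothetical and stand in for real CUDA calls. It rounds each request up to a multiple of 512 bytes, keeps freed blocks in a separate pool per stream, and reuses a cached block only for a later request on the same stream, falling back to a slow device allocation otherwise. As a simplification, blocks are matched by exact rounded size; a real allocator would also split and coalesce blocks.

```python
from collections import defaultdict

ROUND = 512  # allocation granularity in bytes, as described above


def round_up(nbytes, multiple=ROUND):
    """Round a request up to the next multiple of `multiple` bytes."""
    return ((nbytes + multiple - 1) // multiple) * multiple


class CachingAllocator:
    def __init__(self):
        # pools[stream][size] -> cached block handles of exactly that size
        self.pools = defaultdict(lambda: defaultdict(list))
        self.block_info = {}  # handle -> (stream, size)
        self.next_handle = 0
        self.device_malloc = 0  # number of slow "device" allocations

    def malloc(self, nbytes, stream):
        size = round_up(nbytes)
        cached = self.pools[stream][size]
        if cached:
            # Safe to reuse immediately: the stream serializes the previous
            # user's last kernel before any kernel using the new allocation.
            return cached.pop()
        # Cache miss: pay for a real device allocation (e.g. cudaMalloc).
        self.device_malloc += 1
        handle = self.next_handle
        self.next_handle += 1
        self.block_info[handle] = (stream, size)
        return handle

    def free(self, handle):
        # Never hand memory back to the device; keep it in its stream's pool.
        stream, size = self.block_info[handle]
        self.pools[stream][size].append(handle)


alloc = CachingAllocator()
a = alloc.malloc(1000, stream=0)  # rounded to 1024 bytes, device allocation
alloc.free(a)
b = alloc.malloc(700, stream=0)   # also rounds to 1024: reuses block a
c = alloc.malloc(700, stream=1)   # other stream: needs a new device allocation
print(b == a, alloc.device_malloc)  # True 2
```

Reusing a block on another stream would require the extra synchronization mentioned above; this sketch sidesteps that by never sharing pools across streams.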

The one-pool-per-stream design seems limiting since the allocations end up fragmented per stream, but in practice PyTorch almost never uses multiple streams. It is notoriously hard to write CUDA kernels in a way that would let them cooperatively share the GPU because exact scheduling is hardware controlled. In practice, kernel writers usually resort to monolithic kernels that combine multiple tasks. Data loading and distributed computing utilities are exceptions to the one stream design, and they carefully insert additional synchronization to avoid bad interactions with the allocator.

While this design is susceptible to certain corner cases, it almost never exhibits unwanted behaviors in practical code. Most of our users are not aware of its existence.

Multiprocessing

Due to the global interpreter lock (GIL), Python's default implementation does not allow concurrent threads to execute in parallel. To alleviate this problem, the Python community has established a standard multiprocessing module, containing a number of utilities that allow users to easily spawn child processes and implement basic inter-process communication primitives.

However, the implementation of the primitives uses the same form of serialization used for on-disk persistence, which is inefficient when dealing with large arrays. Hence, PyTorch extends the Python multiprocessing module into torch.multiprocessing, which is a drop-in replacement for the built-in package and automatically moves the data of tensors sent to other processes to shared memory instead of sending it over the communication channel.

This design greatly improves performance and makes the process isolation weaker, resulting in a programming model which more closely resembles regular threaded programs. Users can easily implement heavily parallel programs that operate on independent GPUs but later synchronize gradients using all-reduce style primitives.
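The idea of moving the payload into shared memory and sending only a small handle can be illustrated with the standard library's `multiprocessing.shared_memory` module (Python 3.8+). This is an analogy, not how `torch.multiprocessing` is implemented; `share_floats` and `read_floats` are hypothetical helpers. For brevity the "receiver" attaches in the same process; a child process would attach by name in exactly the same way.

```python
import struct
from multiprocessing import shared_memory


def share_floats(values):
    """Copy floats into a new shared memory block and return the block."""
    shm = shared_memory.SharedMemory(create=True, size=8 * len(values))
    struct.pack_into(f"{len(values)}d", shm.buf, 0, *values)
    return shm


def read_floats(name, count):
    """Attach to an existing block by name and read `count` doubles."""
    shm = shared_memory.SharedMemory(name=name)
    try:
        return list(struct.unpack_from(f"{count}d", shm.buf, 0))
    finally:
        shm.close()


# Sender: the (potentially large) payload lives in shared memory ...
block = share_floats([0.5, 1.5, 2.5])
# ... and only this small handle would travel over the pipe to another process.
handle = (block.name, 3)

# Receiver: attach by name; no element-by-element copy through a channel.
print(read_floats(*handle))  # [0.5, 1.5, 2.5]

# Writes through one mapping are visible through the other.
struct.pack_into("d", block.buf, 0, 9.0)
print(read_floats(*handle))  # [9.0, 1.5, 2.5]

block.close()
block.unlink()  # release the segment once all users are done
```

The last two prints show the weaker isolation mentioned above: both sides see the same bytes, much like threads sharing an array.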

Another unique feature of this system is that it transparently handles sharing of CUDA tensors, making it easy to implement techniques like Hogwild (Recht et al. 2011).
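Hogwild lets several workers apply SGD updates to the same parameters without any locking. The following is a minimal sketch on a toy, noise-free least-squares problem; `make_data` and `worker` are hypothetical names, and plain CPython threads are used only to show the lock-free update pattern, not to obtain a speedup. With `torch.multiprocessing`, the workers would instead be processes updating a tensor placed in shared memory.

```python
import random
import threading

params = [0.0]  # shared parameter, updated by every worker without a lock


def make_data(n, true_w=3.0, seed=0):
    """Noise-free samples of y = true_w * x."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, true_w * x) for x in xs]


def worker(shard, lr=0.1, epochs=20):
    for _ in range(epochs):
        for x, y in shard:
            w = params[0]                 # read a possibly stale value
            grad = 2.0 * (w * x - y) * x  # d/dw of (w * x - y) ** 2
            params[0] = w - lr * grad     # write back with no lock (Hogwild)


data = make_data(400)
shards = [data[i::4] for i in range(4)]  # one shard per worker
threads = [threading.Thread(target=worker, args=(s,)) for s in shards]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"learned w = {params[0]:.4f} (true value 3.0)")
```

Occasional lost updates only replace one step toward the optimum with another, which is why Hogwild tolerates the races on sparse or well-conditioned problems.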

Reference counting

Users often design their models to utilize all memory available during training, and increasing batch sizes is a common technique for speeding up the process. Therefore, to deliver great performance, PyTorch has to treat memory as a scarce resource that it needs to manage carefully.

Libraries with eager semantics have to manage tensor memory without knowing how it will be used in the future. Garbage collection is the typical way to handle this automatically because it has good amortized performance. In this approach, the runtime periodically investigates the state of the system, enumerates used objects and frees everything else. However, by deferring the deallocation, it causes the program to use more memory overall (Hertz and Berger 2005). Given the scarcity of GPU memory, these overheads are unacceptable. In fact, Torch7 utilized the garbage collector built into Lua, and a common anti-pattern among the users was to sprinkle the program with explicit triggers to the garbage collector, hoping that the memory errors go away.

PyTorch takes a different approach: it relies on a reference counting scheme to track the number of uses of each tensor, and frees the underlying memory immediately once this count reaches zero. Note that PyTorch tracks both references internal to the libtorch library and external references made by users in their Python code by integrating with Python's own reference counting mechanism. This ensures that memory is released exactly when tensors become unneeded.

One notable caveat is that we can only guarantee the desired performance characteristics in implementations of languages that either already utilize reference counting (CPython and Swift, but not PyPy or many scripting languages such as Lua) or allow for user-defined behavior for assignment, copies, and moves (e.g. C++, Rust). Bindings to implementations that do not satisfy those criteria will have to implement their own specialized memory management on top of PyTorch.
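The difference between immediate reference-counted deallocation and deferred collection can be observed directly in CPython. In this small sketch, `Storage` is a hypothetical stand-in for a tensor's memory block, and `weakref.finalize` is used only to record the moment the object is actually freed. Dropping the last reference frees the object at once, whereas an object kept alive by a reference cycle waits for the cycle collector, which is the kind of deferral that wastes scarce GPU memory.

```python
import gc
import weakref

released = []


class Storage:
    """Hypothetical stand-in for a block of (GPU) memory."""

    def __init__(self, name):
        self.name = name
        # Record the moment this object is actually deallocated.
        weakref.finalize(self, released.append, name)


# 1) Plain reference counting: freed as soon as the count reaches zero.
a = Storage("a")
b = a            # two references to the same storage
del a
print(released)  # [] : 'b' still refers to the storage
del b
print(released)  # ['a'] : released immediately, no collector involved

# 2) A reference cycle keeps the count above zero after `del`.
gc.disable()     # keep the automatic cycle collector out of the way
c = Storage("c")
c.self_ref = c   # cycle: c -> c
del c
print(released)  # ['a'] : 'c' is unreachable but not yet freed
gc.collect()
print(released)  # ['a', 'c'] : only the cycle collector reclaims it
gc.enable()
```

This behavior is specific to CPython; on an implementation such as PyPy, the first case would also be deferred until the garbage collector runs.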

Evaluation

In this section we compare the performance of PyTorch with several other commonly-used deep learning libraries, and find that it achieves competitive performance across a range of tasks. All experiments were performed on a workstation with two Intel Xeon E5-2698 v4 CPUs and one NVIDIA Quadro GP100 GPU.

Asynchronous dataflow

We start by quantifying the ability of PyTorch to asynchronously execute dataflow on GPU. We use the built-in profiler (The PyTorch team, n.d.a) to instrument various benchmarks and record a timeline of the execution of a single training step.

Figure [fig:async_execution]{reference-type="ref" reference="fig:async_execution"} shows a representative timeline of execution for the first few operations of a ResNet-50 model. The host CPU which queues the work quickly outpaces the execution of the operators on the GPU. This allows PyTorch to achieve almost perfect device utilization. In this example, GPU execution takes around three times longer than CPU scheduling. The exact ratio depends on the relative performance of the host CPU and the GPU, as well as the number of elements in each tensor and the average arithmetic complexity of the floating point computations to be performed on the GPU.

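[Figure fig:async_execution: timeline of the first few operations of a ResNet-50 training step, with host CPU scheduling running ahead of GPU execution.]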

Memory management

We used the NVIDIA profiler to trace the execution of the CUDA runtime as well as the execution of the CUDA kernels launched during one training iteration of the ResNet-50 model. As shown in Figure [fig:resnet_annotated_traces]{reference-type="ref" reference="fig:resnet_annotated_traces"}, the behavior of the first iteration differs significantly from that of subsequent ones. At first, calls to the CUDA memory management functions (cudaMalloc and cudaFree) slow down the execution quite dramatically by blocking the CPU thread for long periods of time, hence lowering the utilization of the GPU. This effect disappears in subsequent iterations as the PyTorch caching memory allocator starts reusing previously allocated regions.

[Figure fig:resnet_annotated_traces: annotated traces of CUDA runtime calls and kernel execution during ResNet-50 training iterations.]

Benchmarks

Finally, we can get an overall sense of single-machine eager mode performance of PyTorch by comparing it to three popular graph-based deep learning frameworks (CNTK, MXNet and TensorFlow), a define-by-run framework (Chainer), and a production-oriented platform (PaddlePaddle). The Appendix details all the steps needed to reproduce our setup.

Our results are summarized in Table 1. On all the benchmarks, the performance of PyTorch is within 17% of that of the fastest framework. We attribute this result to the fact that these tools offload most of the computation to the same version of the cuDNN and cuBLAS libraries.

::: {#detailed_perf_results}

| Framework | AlexNet | VGG-19 | ResNet-50 | MobileNet | GNMTv2 | NCF |
|:--|--:|--:|--:|--:|--:|--:|
| Chainer | $778 \pm 15$ | N/A | $\textbf{219} \pm 1$ | N/A | N/A | N/A |
| CNTK | $845 \pm 8$ | $84 \pm 3$ | $210 \pm 1$ | N/A | N/A | N/A |
| MXNet | $\textbf{1554} \pm 22$ | $113 \pm 1$ | $218 \pm 2$ | $444 \pm 2$ | N/A | N/A |
| PaddlePaddle | $933 \pm 123$ | $112 \pm 2$ | $192 \pm 4$ | $\textbf{557} \pm 24$ | N/A | N/A |
| TensorFlow | $1422 \pm 27$ | $66 \pm 2$ | $200 \pm 1$ | $216 \pm 15$ | $9631 \pm 1.3\%$ | $4.8e6 \pm 2.9\%$ |
| PyTorch | $1547 \pm 316$ | $\textbf{119} \pm 1$ | $212 \pm 2$ | $463 \pm 17$ | $\textbf{15512} \pm 4.8\%$ | $\textbf{5.4e6} \pm 3.4\%$ |

Training speed (throughput, higher is better) for 6 models using 32-bit floats. Throughput is measured in images per second for the AlexNet, VGG-19, ResNet-50, and MobileNet models, in tokens per second for the GNMTv2 model, and in samples per second for the NCF model. The fastest speed for each model is shown in bold. :::
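The "within 17%" statement can be checked directly against the table. The short computation below is our own arithmetic on the mean throughputs reported in Table 1 (ignoring the variances): for each model it divides PyTorch's throughput by the best result and reports the largest relative gap.

```python
# Mean throughputs from Table 1: (PyTorch, fastest framework) per model.
throughput = {
    "AlexNet": (1547, 1554),   # fastest: MXNet
    "VGG-19": (119, 119),      # fastest: PyTorch
    "ResNet-50": (212, 219),   # fastest: Chainer
    "MobileNet": (463, 557),   # fastest: PaddlePaddle
    "GNMTv2": (15512, 15512),  # fastest: PyTorch
    "NCF": (5.4e6, 5.4e6),     # fastest: PyTorch
}

# Relative shortfall of PyTorch compared with the fastest framework.
gaps = {
    model: 1.0 - pytorch / fastest
    for model, (pytorch, fastest) in throughput.items()
}
worst = max(gaps, key=gaps.get)
print(worst, f"{gaps[worst]:.1%}")  # MobileNet 16.9%
```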

Adoption

The validity of design decisions and their impact on ease-of-use is hard to measure. As a proxy, we tried to quantify how well the machine learning community received PyTorch by counting how often various machine learning tools (including Caffe, Chainer, CNTK, Keras, MXNet, PyTorch, TensorFlow, and Theano) have been mentioned in arXiv e-Prints since the initial release of PyTorch in January 2017. In Figure [fig:pytorch_references]{reference-type="ref" reference="fig:pytorch_references"} we report the monthly number of mentions of the word "PyTorch" as a percentage of all mentions among these deep learning frameworks. We counted tools mentioned multiple times in a given paper only once, and made the search case-insensitive to account for various spellings.

[Figure fig:pytorch_references: monthly mentions of PyTorch on arXiv as a percentage of all mentions of the listed deep learning frameworks.]
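The counting rule described above (each framework counted at most once per paper, case-insensitive matching) can be sketched as follows. The input format and the `pytorch_share` helper are hypothetical, and the sketch uses naive substring matching; it does not reproduce the actual arXiv pipeline used for the figure.

```python
from collections import defaultdict

FRAMEWORKS = ["caffe", "chainer", "cntk", "keras", "mxnet",
              "pytorch", "tensorflow", "theano"]


def pytorch_share(papers):
    """papers: iterable of (month, full_text). Returns {month: percent}."""
    mentions = defaultdict(lambda: defaultdict(int))
    for month, text in papers:
        lowered = text.lower()  # case-insensitive search
        # A set, so each framework counts at most once per paper.
        found = {fw for fw in FRAMEWORKS if fw in lowered}
        for fw in found:
            mentions[month][fw] += 1
    share = {}
    for month, counts in mentions.items():
        total = sum(counts.values())
        share[month] = 100.0 * counts["pytorch"] / total
    return share


papers = [
    ("2019-01", "We use PyTorch. Our pytorch code also mirrors TensorFlow."),
    ("2019-01", "Implemented in Keras with a TensorFlow backend."),
]
print(pytorch_share(papers))  # {'2019-01': 25.0}
```

Here the first paper mentions PyTorch twice but contributes a single count, giving one PyTorch mention out of four framework mentions that month.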

Conclusion and future work

PyTorch has become a popular tool in the deep learning research community by combining a focus on usability with careful performance considerations. In addition to continuing to support the latest trends and advances in deep learning, in the future we plan to continue to improve the speed and scalability of PyTorch. Most notably, we are working on the PyTorch JIT: a suite of tools that allow PyTorch programs to be executed outside of the Python interpreter where they can be further optimized. We also intend to improve support for distributed computation by providing efficient primitives for data parallelism as well as a Pythonic library for model parallelism based around remote procedure calls.

Acknowledgements

We are grateful to the PyTorch community for their feedback and contributions that greatly influenced the design and implementation of PyTorch. We thank all the PyTorch core team members, contributors and package maintainers including Ailing Zhang, Alex Suhan, Alfredo Mendoza, Alican Bozkurt, Andrew Tulloch, Ansha Yu, Anthony Shoumikhin, Bram Wasti, Brian Vaughan, Christian Puhrsch, David Reiss, David Riazati, Davide Libenzi, Dmytro Dzhulgakov, Dwaraj Rajagopal, Edward Yang, Elias Ellison, Fritz Obermeyer, George Zhang, Hao Lu, Hong Xu, Hung Duong, Igor Fedan, Ilia Cherniavskii, Iurii Zdebskyi, Ivan Kobzarev, James Reed, Jeff Smith, Jerry Chen, Jerry Zhang, Jiakai Liu, Johannes M. Dieterich, Karl Ostmo, Lin Qiao, Martin Yuan, Michael Suo, Mike Ruberry, Mikhail Zolothukhin, Mingzhe Li, Neeraj Pradhan, Nick Korovaiko, Owen Anderson, Pavel Belevich, Peter Johnson, Pritam Damania, Raghuraman Krishnamoorthi, Richard Zou, Roy Li, Rui Zhu, Sebastian Messmer, Shen Li, Simon Wang, Supriya Rao, Tao Xu, Thomas Viehmann, Vincent Quenneville-Belair, Vishwak Srinivasan, Vitaly Fedyunin, Wanchao Liang, Wei Yang, Will Feng, Xiaomeng Yang, Xiaoqiang Zheng, Xintao Chen, Yangqing Jia, Yanli Zhao, Yinghai Lu and Zafar Takhirov.

::: {#refs .references .csl-bib-body .hanging-indent} ::: {#ref-TF .csl-entry} Abadi, Martín, Ashish Agarwal, Paul Barham, Eugene Brevdo, Zhifeng Chen, Craig Citro, Greg S. Corrado, et al. 2015. "TensorFlow: Large-Scale Machine Learning on Heterogeneous Systems." https://www.tensorflow.org/. :::

::: {#ref-APL .csl-entry} Abrams, Philip S. 1970. "An APL Machine." PhD thesis, Stanford University. :::

::: {#ref-jax .csl-entry} Johnson, Matthew, et al. 2018. "Jax." GitHub Repository. https://github.com/google/jax. :::

::: {#ref-flux .csl-entry} Innes, Mike, et al. 2018. "Flux.jl." GitHub Repository. https://github.com/FluxML/Flux.jl. :::

::: {#ref-autodiff_survey .csl-entry} Baydin, Atilim Gunes, Barak A. Pearlmutter, Alexey Andreyevich Radul, and Jeffrey Mark Siskind. 2017. "Automatic Differentiation in Machine Learning: A Survey." J. Mach. Learn. Res. 18 (1): 5595--5637. http://dl.acm.org/citation.cfm?id=3122009.3242010. :::

::: {#ref-hoard .csl-entry} Berger, Emery D., Kathryn S. McKinley, Robert D. Blumofe, and Paul R. Wilson. 2000. "Hoard: A Scalable Memory Allocator for Multithreaded Applications." In Proceedings of the Ninth International Conference on Architectural Support for Programming Languages and Operating Systems, 117--28. ASPLOS IX. New York, NY, USA: ACM. https://doi.org/10.1145/378993.379232. :::

::: {#ref-Julia .csl-entry} Bezanson, Jeff, Alan Edelman, Stefan Karpinski, and Viral B Shah. 2017. "Julia: A Fresh Approach to Numerical Computing." SIAM Review 59 (1): 65--98. https://doi.org/10.1137/141000671. :::

::: {#ref-cudnn .csl-entry} Chetlur, Sharan, Cliff Woolley, Philippe Vandermersch, Jonathan D. Cohen, John Tran, Bryan Catanzaro, and Evan Shelhamer. 2014. "cuDNN: Efficient Primitives for Deep Learning." CoRR abs/1410.0759. :::

::: {#ref-Torch .csl-entry} Collobert, Ronan, Samy Bengio, and Johnny Mariéthoz. 2002. "Torch: A Modular Machine Learning Software Library." Idiap. :::

::: {#ref-Torch7 .csl-entry} Collobert, Ronan, Koray Kavukcuoglu, and Clément Farabet. 2011. "Torch7: A Matlab-Like Environment for Machine Learning." In NIPS 2011. :::

::: {#ref-dlpack .csl-entry} DMLC. n.d. "DLPack: Open in Memory Tensor Structure." :::

::: {#ref-jemalloc .csl-entry} Evans, J. 2006. "A Scalable Concurrent Malloc(3) Implementation for FreeBSD." In BSDCan: The Technical BSD Conference. Ottawa, Canada, May 2006. http://people.freebsd.org/~jasone/jemalloc/bsdcan2006/jemalloc.pdf. :::

::: {#ref-worse_is_better .csl-entry} Gabriel, Richard. n.d. "The Rise of Worse Is Better." :::

::: {#ref-tcmalloc .csl-entry} Ghemawat, S., and P. Menage. n.d. "Tcmalloc: Thread-Caching Malloc." http://goog-perftools.sourceforge.net/doc/tcmalloc.html. :::

::: {#ref-eigenweb .csl-entry} Guennebaud, Gaël, Benoît Jacob, et al. 2010. "Eigen V3." http://eigen.tuxfamily.org. :::

::: {#ref-garbage_collection .csl-entry} Hertz, Matthew, and Emery D. Berger. 2005. "Quantifying the Performance of Garbage Collection Vs. Explicit Memory Management." In Proceedings of the 20th Annual ACM SIGPLAN Conference on Object-Oriented Programming, Systems, Languages, and Applications, 313--26. OOPSLA '05. New York, NY, USA: ACM. https://doi.org/10.1145/1094811.1094836. :::

::: {#ref-hasktorch .csl-entry} Huang, Austin, Junji Hashimoto, and Sam Stites. n.d. "HaskTorch." :::

::: {#ref-Caffe .csl-entry} Jia, Yangqing, Evan Shelhamer, Jeff Donahue, Sergey Karayev, Jonathan Long, Ross Girshick, Sergio Guadarrama, and Trevor Darrell. 2014. "Caffe: Convolutional Architecture for Fast Feature Embedding." arXiv Preprint arXiv:1408.5093. :::

::: {#ref-SciPy .csl-entry} Jones, Eric, Travis Oliphant, Pearu Peterson, et al. 2001--. "SciPy: Open Source Scientific Tools for Python." :::

::: {#ref-maxdnn .csl-entry} Lavin, Andrew. 2015. "maxDNN: An Efficient Convolution Kernel for Deep Learning with Maxwell GPUs." :::

::: {#ref-fast_cnn .csl-entry} Lavin, Andrew, and Scott Gray. 2016. "Fast Algorithms for Convolutional Neural Networks." 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 4013--21. :::

::: {#ref-mnist .csl-entry} LeCun, Yann, and Corinna Cortes. n.d. "MNIST Handwritten Digit Database." http://yann.lecun.com/exdb/mnist/. :::

::: {#ref-Lush .csl-entry} LeCun, Y, and L Bottou. 2002. "Lush Reference Manual." code available at http://lush.sourceforge.net. :::

::: {#ref-Leuck-dual-numbers .csl-entry} Leuck, Holger, and Hans-Hellmut Nagel. 1999. "Automatic Differentiation Facilitates OF-Integration into Steering-Angle-Based Road Vehicle Tracking." In 1999 Conference on Computer Vision and Pattern Recognition (CVPR '99), 23-25 June 1999, Ft. Collins, CO, USA, 2360--65. https://doi.org/10.1109/CVPR.1999.784659. :::

::: {#ref-cuda_stream .csl-entry} Luitjens, Justin. 2014. "CUDA Streams." http://on-demand.gputechconf.com/gtc/2014/presentations/S4158-cuda-streams-best-practices-common-pitfalls.pdf. :::

::: {#ref-maclaurin2016phd .csl-entry} Maclaurin, Dougal. 2016. "Modeling, Inference and Optimization with Composable Differentiable Procedures." PhD thesis, Harvard University. :::

::: {#ref-Matlab .csl-entry} MATLAB and Statistics Toolbox. n.d. Natick, Massachusetts, United States: The MathWorks, Inc. :::

::: {#ref-Pandas .csl-entry} McKinney, Wes. 2010. "Data Structures for Statistical Computing in Python." In Proceedings of the 9th Python in Science Conference, 51-56. :::

::: {#ref-DyNet .csl-entry} Neubig, G., C. Dyer, Y. Goldberg, A. Matthews, W. Ammar, A. Anastasopoulos, M. Ballesteros, et al. 2017. "DyNet: The Dynamic Neural Network Toolkit." ArXiv e-Prints, January. https://arxiv.org/abs/1701.03980. :::

::: {#ref-Numpy .csl-entry} Oliphant, Travis. 2006. "NumPy: A Guide to NumPy." USA: Trelgol Publishing. :::

::: {#ref-pytorch_autodiff .csl-entry} Paszke, Adam, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. 2017. "Automatic Differentiation in PyTorch." In NIPS Workshop. :::

::: {#ref-nimtorch .csl-entry} Petrantoni, Giovanni, and Jörg Wollenschläger. n.d. "NimTorch." :::

::: {#ref-Piponi-dual-numbers .csl-entry} Piponi, Dan. 2004. "Automatic Differentiation, C++ Templates, and Photogrammetry." J. Graphics, GPU, & Game Tools 9 (4): 41--55. https://doi.org/10.1080/10867651.2004.10504901. :::

::: {#ref-R .csl-entry} R Core Team. n.d. R: A Language and Environment for Statistical Computing. Vienna, Austria: R Foundation for Statistical Computing. http://www.R-project.org/. :::

::: {#ref-Hogwild .csl-entry} Recht, Benjamin, Christopher Ré, Stephen J. Wright, and Feng Niu. 2011. "Hogwild: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent." In Advances in Neural Information Processing Systems 24: 25th Annual Conference on Neural Information Processing Systems 2011. Proceedings of a Meeting Held 12-14 December 2011, Granada, Spain., 693--701. http://papers.nips.cc/paper/4390-hogwild-a-lock-free-approach-to-parallelizing-stochastic-gradient-descent. :::

::: {#ref-CNTK .csl-entry} Seide, Frank, and Amit Agarwal. 2016. "CNTK: Microsoft's Open-Source Deep-Learning Toolkit." In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2135. KDD '16. New York, NY, USA: ACM. https://doi.org/10.1145/2939672.2945397. :::

::: {#ref-EBLearn .csl-entry} Sermanet, Pierre, Koray Kavukcuoglu, and Yann LeCun. 2009. "Eblearn: Open-Source Energy-Based Learning in c++." In 2009 21st IEEE International Conference on Tools with Artificial Intelligence, 693--97. IEEE. :::

::: {#ref-starcraft_pytorch .csl-entry} Synnaeve, G., Z. Lin, J. Gehring, D. Gant, V. Mella, V. Khalidov, N. Carion, and N. Usunier. 2018. "Forward Modeling for Partial Observation Strategy Games - a Starcraft Defogger." In Advances in Neural Information Processing Systems, 10761--71. :::

::: {#ref-python_gil .csl-entry} The Python team. n.d. "The CPython Global Interpreter Lock." :::

::: {#ref-autograd_profiler .csl-entry} The PyTorch team. n.d.a. PyTorch Autograd Profiler. :::

::: {#ref-torchscript .csl-entry} ---------. n.d.b. Torch Script. :::

::: {#ref-Theano .csl-entry} Theano Development Team. 2016. "Theano: A Python framework for fast computation of mathematical expressions." arXiv e-Prints abs/1605.02688 (May). http://arxiv.org/abs/1605.02688. :::

::: {#ref-Chainer .csl-entry} Tokui, Seiya, Kenta Oono, Shohei Hido, and Justin Clayton. 2015. "Chainer: A Next-Generation Open Source Framework for Deep Learning." In Proceedings of Workshop on Machine Learning Systems (LearningSys) in the Twenty-Ninth Annual Conference on Neural Information Processing Systems (NIPS). http://learningsys.org/papers/LearningSys_2015_paper_33.pdf. :::

::: {#ref-starcraft2 .csl-entry} Vinyals, Oriol, Timo Ewalds, Sergey Bartunov, Petko Georgiev, Alexander Sasha Vezhnevets, Michelle Yeo, Alireza Makhzani, et al. 2017. "StarCraft II: A New Challenge for Reinforcement Learning." CoRR abs/1708.04782. http://arxiv.org/abs/1708.04782. ::: :::

root[124] (1:1-850:1, 0-43048)
├─0 heading[1] (1:1-1:15, 0-14)
│ │ depth: 1
│ └─0 text "Introduction" (1:3-1:15, 2-14)
├─1 paragraph[1] (3:1-13:13, 16-713)
│ └─0 text "With the increased interest in deep learning in recent years, there has\nbeen an explosion of machine learning tools. Many popular frameworks\nsuch as Caffe (\"Jia et al. \"2014\"), CNTK (Seide and Agarwal 2016),\nTensorFlow (Abadi et al. 2015), and Theano (Theano Development Team\n2016), construct a static dataflow graph that represents the computation\nand which can then be applied repeatedly to batches of data. This\napproach provides visibility into the whole computation ahead of time,\nand can theoretically be leveraged to improve performance and\nscalability. However, it comes at the cost of ease of use, ease of\ndebugging, and flexibility of the types of computation that can be\nrepresented." (3:1-13:13, 16-713)
├─2 paragraph[1] (15:1-20:42, 715-1091)
│ └─0 text "Prior work has recognized the value of dynamic eager execution for deep\nlearning, and some recent frameworks implement this define-by-run\napproach, but do so either at the cost of performance (Chainer (Tokui et\nal. 2015)) or using a less expressive, faster language\n(Torch (Collobert, Bengio, and Mariéthoz 2002), DyNet (Neubig et al.\n2017)), which limits their applicability." (15:1-20:42, 715-1091)
├─3 paragraph[1] (22:1-29:66, 1093-1644)
│ └─0 text "However, with careful implementation and design choices, dynamic eager\nexecution can be achieved largely without sacrificing performance. This\npaper introduces PyTorch, a Python library that performs immediate\nexecution of dynamic tensor computations with automatic differentiation\nand GPU acceleration, and does so while maintaining performance\ncomparable to the fastest current libraries for deep learning. This\ncombination has turned out to be very popular in the research community\nwith, for instance, 296 ICLR 2019 submissions mentioning PyTorch." (22:1-29:66, 1093-1644)
├─4 heading[1] (31:1-31:13, 1646-1658)
│ │ depth: 1
│ └─0 text "Background" (31:3-31:13, 1648-1658)
├─5 paragraph[1] (33:1-34:29, 1660-1755)
│ └─0 text "Four major trends in scientific computing have become increasingly\nimportant for deep learning." (33:1-34:29, 1660-1755)
├─6 paragraph[5] (36:1-45:35, 1757-2423)
│ ├─0 text "First, starting in the 1960s, the development of domain specific\nlanguages such as APL (Abrams 1970), MATLAB (" (36:1-37:46, 1757-1867)
│ ├─1 emphasis[1] (37:46-38:9, 1867-1898)
│ │ └─0 text "MATLAB and Statistics\nToolbox" (37:47-38:8, 1868-1897)
│ ├─2 text ", n.d.), R (R Core Team, n.d.) and Julia (Bezanson et al. 2017),\nturned multidimensional arrays (often referred to as tensors) into\nfirst-class objects supported by a comprehensive set of mathematical\nprimitives (or operators) to manipulate them. Separately, libraries such\nas NumPy(Oliphant 2006), Torch(Collobert, Bengio, and Mariéthoz 2002),\nEigen(Guennebaud, Jacob, et al. 2010) and Lush(Y. LeCun and Bottou 2002)\nmade " (38:9-44:6, 1898-2321)
│ ├─3 strong[1] (44:6-44:33, 2321-2348)
│ │ └─0 text "array-based programming" (44:8-44:31, 2323-2346)
│ └─4 text " productive in general purpose languages\nsuch as Python, Lisp, C++ and Lua." (44:33-45:35, 2348-2423)
├─7 paragraph[3] (47:1-56:34, 2425-3081)
│ ├─0 text "Second, the development of " (47:1-47:28, 2425-2452)
│ ├─1 strong[1] (47:28-47:57, 2452-2481)
│ │ └─0 text "automatic differentiation" (47:30-47:55, 2454-2479)
│ └─2 text " (Baydin et al.\n2017) made it possible to fully automate the daunting labor of computing\nderivatives. This made it significantly easier to experiment with\ndifferent machine learning approaches while still allowing for efficient\ngradient based optimization. The autograd (Maclaurin 2016) package\npopularized the use of this technique for NumPy arrays, and similar\napproaches are used in frameworks such as Chainer (Tokui et al. 2015),\nDyNet (Neubig et al. 2017), Lush (Y. LeCun and Bottou 2002),\nTorch (Collobert, Bengio, and Mariéthoz 2002), Jax (M. J. et. al. 2018)\nand Flux.jl (M. I. et. al. 2018)." (47:57-56:34, 2481-3081)
├─8 paragraph[5] (58:1-78:9, 3083-4454)
│ ├─0 text "Third, with the advent of the free software movement, the scientific\ncommunity moved away from closed proprietary software such as\nMatlab(" (58:1-60:8, 3083-3221)
│ ├─1 emphasis[1] (60:8-60:39, 3221-3252)
│ │ └─0 text "MATLAB and Statistics Toolbox" (60:9-60:38, 3222-3251)
│ ├─2 text ", n.d.), and towards the\n" (60:39-61:1, 3252-3277)
│ ├─3 strong[1] (61:1-61:33, 3277-3309)
│ │ └─0 text "open-source Python ecosystem" (61:3-61:31, 3279-3307)
│ └─4 text " with packages like NumPy (Oliphant\n2006), SciPy (Jones et al. 2001--), and Pandas (McKinney 2010). This\nfulfilled most of the numerical analysis needs of researchers while\nallowing them to take advantage of a vast repository of libraries to\nhandle dataset preprocessing, statistical analysis, plotting, and more.\nMoreover, the openness, interoperability, and flexibility of free\nsoftware fostered the development of vibrant communities that could\nquickly address new or changing needs by extending the existing\nfunctionality of a library or if needed by developing and releasing\nbrand new ones. While there is a rich offering of open-source software\nfor neural networks in languages other than Python, starting with\nLush (Y. LeCun and Bottou 2002) in Lisp, Torch (Collobert, Bengio, and\nMariéthoz 2002) in C++, Objective-C and Lua, EBLearn (Sermanet,\nKavukcuoglu, and LeCun 2009) in C++, Caffe (\"Jia et al. \"2014\") in\nC++, the network effects of a large ecosystem such as Python made it an\nessential skill to jumpstart one's research. Hence, since 2014, most\ndeep learning frameworks converged on a Python interface as an essential\nfeature." (61:33-78:9, 3309-4454)
├─9 paragraph[3] (80:1-88:36, 4456-5037)
│ ├─0 text "Finally, the availability and commoditization of general-purpose\nmassively parallel hardware such as GPUs provided the computing power\nrequired by deep learning methods. Specialized libraries such as\ncuDNN (Chetlur et al. 2014), along with a body of academic work (such as\n(Lavin 2015) and (Lavin and Gray 2016)), produced a set of\nhigh-performance reusable deep learning kernels that enabled frameworks\nsuch as Caffe (\"Jia et al. \"2014\"), Torch7 (Collobert, Kavukcuoglu,\nand Farabet 2011), or TensorFlow (Abadi et al. 2015) to take advantage\nof these " (80:1-88:10, 4456-5011)
│ ├─1 strong[1] (88:10-88:35, 5011-5036)
│ │ └─0 text "hardware accelerators" (88:12-88:33, 5013-5034)
│ └─2 text "." (88:35-88:36, 5036-5037)
├─10 paragraph[1] (90:1-92:52, 5039-5220)
│ └─0 text "PyTorch builds on these trends by providing an array-based programming\nmodel accelerated by GPUs and differentiable via automatic\ndifferentiation integrated in the Python ecosystem." (90:1-92:52, 5039-5220)
├─11 heading[1] (94:1-94:20, 5222-5241)
│ │ depth: 1
│ └─0 text "Design principles" (94:3-94:20, 5224-5241)
├─12 paragraph[1] (96:1-98:13, 5243-5396)
│ └─0 text "PyTorch's success stems from weaving previous ideas into a design that\nbalances speed and ease of use. There are four main principles behind\nour choices:" (96:1-98:13, 5243-5396)
├─13 paragraph[2] (100:1-105:57, 5398-5797)
│ ├─0 strong[1] (100:1-100:16, 5398-5413)
│ │ └─0 text "Be Pythonic" (100:3-100:14, 5400-5411)
│ └─1 text " Data scientists are familiar with the Python language,\nits programming model, and its tools. PyTorch should be a first-class\nmember of that ecosystem. It follows the commonly established design\ngoals of keeping interfaces simple and consistent, ideally with one\nidiomatic way of doing things. It also integrates naturally with\nstandard plotting, debugging, and data processing tools." (100:16-105:57, 5413-5797)
├─14 paragraph[2] (107:1-111:48, 5799-6114)
│ ├─0 strong[1] (107:1-107:26, 5799-5824)
│ │ └─0 text "Put researchers first" (107:3-107:24, 5801-5822)
│ └─1 text " PyTorch strives to make writing models, data\nloaders, and optimizers as easy and productive as possible. The\ncomplexity inherent to machine learning should be handled internally by\nthe PyTorch library and hidden behind intuitive APIs free of\nside-effects and unexpected performance cliffs." (107:26-111:48, 5824-6114)
├─15 paragraph[4] (113:1-121:15, 6116-6680)
│ ├─0 strong[1] (113:1-113:34, 6116-6149)
│ │ └─0 text "Provide pragmatic performance" (113:3-113:32, 6118-6147)
│ ├─1 text " To be useful, PyTorch needs to deliver\ncompelling performance, although not at the expense of simplicity and\nease of use. Trading 10% of speed for a significantly simpler to use\nmodel is acceptable; 100% is not. Therefore, its " (113:34-116:50, 6149-6377)
│ ├─2 emphasis[1] (116:50-116:66, 6377-6393)
│ │ └─0 text "implementation" (116:51-116:65, 6378-6392)
│ └─3 text "\naccepts added complexity in order to deliver that performance.\nAdditionally, providing tools that allow researchers to manually control\nthe execution of their code will empower them to find their own\nperformance improvements independent of those that the library provides\nautomatically." (116:66-121:15, 6393-6680)
├─16 paragraph[2] (123:1-129:29, 6682-7131)
│ ├─0 strong[1] (123:1-123:20, 6682-6701)
│ │ └─0 text "Worse is better" (123:3-123:18, 6684-6699)
│ └─1 text " (Gabriel, n.d.) Given a fixed amount of engineering\nresources, and all else being equal, the time saved by keeping the\ninternal implementation of PyTorch simple can be used to implement\nadditional features, adapt to new situations, and keep up with the fast\npace of progress in the field of AI. Therefore, it is better to have a\nsimple but slightly incomplete solution than a comprehensive but complex\nand hard-to-maintain design." (123:20-129:29, 6701-7131)
├─17 heading[1] (131:1-131:27, 7133-7159)
│ │ depth: 1
│ └─0 text "Usability centric design" (131:3-131:27, 7135-7159)
├─18 heading[1] (133:1-133:49, 7161-7209)
│ │ depth: 2
│ └─0 text "Deep learning models are just Python programs" (133:4-133:49, 7164-7209)
├─19 paragraph[1] (135:1-148:52, 7211-8140)
│ └─0 text "In a surprisingly short amount of time, machine learning grew from\nrecognizing individual digits (Yann LeCun and Cortes, n.d.) into\nautonomously playing StarCraft (Vinyals et al. 2017). Consequently, the\nneural networks themselves evolved rapidly from simple sequences of\nfeed-forward layers into incredibly varied numerical programs often composed\nof many loops and recursive functions. To support this growing\ncomplexity, PyTorch forgoes the potential benefits of a\ngraph-metaprogramming-based approach to preserve the imperative\nprogramming model of Python. This design was pioneered for model\nauthoring by Chainer (Tokui et al. 2015) and DyNet (Neubig et al. 2017).\nPyTorch extends this to all aspects of deep learning workflows. Defining\nlayers, composing models, loading data, running optimizers, and\nparallelizing the training process are all expressed using the familiar\nconcepts developed for general-purpose programming." (135:1-148:52, 7211-8140)
├─20 paragraph[3] (150:1-165:7, 8142-9138)
│ ├─0 text "This solution ensures that any new potential neural network architecture\ncan be easily implemented with PyTorch. For instance, layers (which in\nmodern machine learning should really be understood as stateful\nfunctions with implicit parameters) are typically expressed as Python\nclasses whose constructors create and initialize their parameters, and\nwhose forward methods process an input activation. Similarly, models are\nusually represented as classes that compose individual layers, but let\nus state again that nothing forces the user to structure their code in\nthat way. Listing\n" (150:1-159:1, 8142-8724)
│ ├─1 link[1] (159:1-159:42, 8724-8765)
│ │ │ title: null
│ │ │ url: "#lst:code_example"
│ │ └─0 text "[lst:code_example]" (159:2-159:22, 8725-8745)
│ └─2 text " demonstrates how an entire model can be\ncreated by composing functionality provided by PyTorch, such as 2D\nconvolution, matrix multiplication, dropout, and softmax, to classify\ngray-scale images. Note that linear layers are of course part of the\nlibrary, but we show an example implementation to highlight how simple\nit is." (159:42-165:7, 8765-9138)
├─21 paragraph[1] (167:1-169:4, 9140-9163)
│ └─0 text "::: {.parcolumns}\n2\n:::" (167:1-169:4, 9140-9163)
├─22 paragraph[1] (171:1-171:47, 9165-9211)
│ └─0 text "[]{#lst:code_example label=\"lst:code_example\"}" (171:1-171:47, 9165-9211)
├─23 paragraph[3] (173:1-181:65, 9213-9829)
│ ├─0 text "This \"everything is just a program\" philosophy is not limited to\nthe models; it applies to optimizers and data loaders as well. This\nfacilitates experimentation with new training techniques. For example,\nto implement the popular generative adversarial networks, one needs\nto specify two separate models (the generator and the discriminator),\nand two loss functions that depend on both models at the same time.\nRigid APIs would struggle with this setup, but the simple design\nemployed in PyTorch easily adapts to this setting, as shown in\nListing " (173:1-181:9, 9213-9773)
│ ├─1 link[1] (181:9-181:22, 9773-9786)
│ │ │ title: null
│ │ │ url: "#lst:gan"
│ │ └─0 text "1" (181:10-181:11, 9774-9775)
│ └─2 text "." (181:22-181:65, 9786-9829)
├─24 html "<figure id=\"lst:gan\">\n<div class=\"sourceCode\" id=\"cb1\" data-fontsize=\"\\small\"><pre\nclass=\"sourceCode python\"><code class=\"sourceCode python\"><span id=\"cb1-1\"><a href=\"#cb1-1\" aria-hidden=\"true\" tabindex=\"-1\"></a>discriminator <span class=\"op\">=</span> create_discriminator()</span>\n<span id=\"cb1-2\"><a href=\"#cb1-2\" aria-hidden=\"true\" tabindex=\"-1\"></a>generator <span class=\"op\">=</span> create_generator()</span>\n<span id=\"cb1-3\"><a href=\"#cb1-3\" aria-hidden=\"true\" tabindex=\"-1\"></a>optimD <span class=\"op\">=</span> optim.Adam(discriminator.parameters())</span>\n<span id=\"cb1-4\"><a href=\"#cb1-4\" aria-hidden=\"true\" tabindex=\"-1\"></a>optimG <span class=\"op\">=</span> optim.Adam(generator.parameters())</span>\n<span id=\"cb1-5\"><a href=\"#cb1-5\" aria-hidden=\"true\" tabindex=\"-1\"></a></span>\n<span id=\"cb1-6\"><a href=\"#cb1-6\" aria-hidden=\"true\" tabindex=\"-1\"></a><span class=\"kw\">def</span> step(real_sample):</span>\n<span id=\"cb1-7\"><a href=\"#cb1-7\" aria-hidden=\"true\" tabindex=\"-1\"></a> <span class=\"co\"># (1) Update Discriminator</span></span>\n<span id=\"cb1-8\"><a href=\"#cb1-8\" aria-hidden=\"true\" tabindex=\"-1\"></a> errD_real <span class=\"op\">=</span> loss(discriminator(real_sample), real_label)</span>\n<span id=\"cb1-9\"><a href=\"#cb1-9\" aria-hidden=\"true\" tabindex=\"-1\"></a> errD_real.backward()</span>\n<span id=\"cb1-10\"><a href=\"#cb1-10\" aria-hidden=\"true\" tabindex=\"-1\"></a> fake <span class=\"op\">=</span> generator(get_noise())</span>\n<span id=\"cb1-11\"><a href=\"#cb1-11\" aria-hidden=\"true\" tabindex=\"-1\"></a> errD_fake <span class=\"op\">=</span> loss(discriminator(fake.detach()), fake_label)</span>\n<span id=\"cb1-12\"><a href=\"#cb1-12\" aria-hidden=\"true\" tabindex=\"-1\"></a> errD_fake.backward()</span>\n<span id=\"cb1-13\"><a href=\"#cb1-13\" aria-hidden=\"true\" tabindex=\"-1\"></a> optimD.step()</span>\n<span id=\"cb1-14\"><a href=\"#cb1-14\" aria-hidden=\"true\" tabindex=\"-1\"></a> <span class=\"co\"># (2) Update Generator</span></span>\n<span id=\"cb1-15\"><a href=\"#cb1-15\" aria-hidden=\"true\" tabindex=\"-1\"></a> errG <span class=\"op\">=</span> loss(discriminator(fake), real_label)</span>\n<span id=\"cb1-16\"><a href=\"#cb1-16\" aria-hidden=\"true\" tabindex=\"-1\"></a> errG.backward()</span>\n<span id=\"cb1-17\"><a href=\"#cb1-17\" aria-hidden=\"true\" tabindex=\"-1\"></a> optimG.step()</span></code></pre></div>\n<p><span id=\"lst:gan\" label=\"lst:gan\"></span></p>\n</figure>" (183:1-203:10, 9831-12190)
├─25 paragraph[1] (205:1-211:43, 12192-12644)
│ └─0 text "Since PyTorch programs execute eagerly, all the features of Python are\navailable throughout the whole design process. Print statements,\nstandard debuggers, and common visualization tools like matplotlib all\nwork as expected. Users do not have to wait for lengthy compilation\nbefore they can start running their programs, and, more importantly,\nintermediate computations can be observed to understand how a model\nworks and whether its results are correct." (205:1-211:43, 12192-12644)
├─26 heading[1] (213:1-213:38, 12646-12683)
│ │ depth: 2
│ └─0 text "Interoperability and extensibility" (213:4-213:38, 12649-12683)
├─27 paragraph[5] (215:1-226:43, 12685-13508)
│ ├─0 text "Easy and efficient interoperability is one of the top priorities for\nPyTorch because it makes it possible to leverage the rich ecosystem\nof Python libraries as part of user programs. Hence, PyTorch allows for\nbidirectional exchange of data with external libraries. For example, it\nprovides a mechanism to convert between NumPy arrays and PyTorch tensors\nusing the " (215:1-220:11, 12685-13053)
│ ├─1 inlineCode "torch.from_numpy()" (220:11-220:31, 13053-13073)
│ ├─2 text " function and " (220:31-220:45, 13073-13087)
│ ├─3 inlineCode ".numpy()" (220:45-220:55, 13087-13097)
│ └─4 text " tensor method.\nSimilar functionality is also available to exchange data stored using\nthe DLPack (DMLC, n.d.) format. Note that in both cases this exchange\nhappens without any data copying: objects on both sides only describe\nhow to interpret a memory region which is shared among them. Hence,\nthose operations are extremely cheap, and take constant time no\nmatter how large the converted arrays are." (220:55-226:43, 13097-13508)
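The zero-copy idea can be illustrated without PyTorch or NumPy. The sketch below is a standard-library analogy, not the `torch.from_numpy()` implementation. Two `memoryview` objects interpret one shared `bytearray` in different ways, so a write through one view is visible through the other. Creating a view copies no element data, whatever the buffer size. `make_shared_views` is a hypothetical helper.

```python
import struct


def make_shared_views(n_floats):
    """Return one raw buffer plus two views that reinterpret the same memory.

    Analogy only (hypothetical helper): `raw` plays the role of the shared
    memory region; each view merely describes how to interpret it, like a
    NumPy array and a PyTorch tensor sharing one storage.
    """
    raw = bytearray(8 * n_floats)              # the single shared region
    as_doubles = memoryview(raw).cast("d")     # view 1: float64 elements
    as_bytes = memoryview(raw)                 # view 2: raw bytes
    return raw, as_doubles, as_bytes


raw, doubles, byte_view = make_shared_views(4)
doubles[2] = 1.5                               # write through one view...
# ...and read it back through the other: element 2 starts at byte offset 16
value = struct.unpack_from("d", byte_view, 16)[0]
print(value)  # 1.5
```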
├─28 paragraph[15] (228:1-242:30, 13510-14504)
│ ├─0 text "Moreover, many of the critical systems are designed specifically to be\nextensible. For instance, the automatic differentiation system allows\nusers to add support for custom differentiable functions. To do that,\nusers can define a new subclass of " (228:1-231:36, 13510-13755)
│ ├─1 inlineCode "torch.autograd.Function" (231:36-231:61, 13755-13780)
│ ├─2 text " that\nimplements " (231:61-232:12, 13780-13797)
│ ├─3 inlineCode "forward()" (232:12-232:23, 13797-13808)
│ ├─4 text " and " (232:23-232:28, 13808-13813)
│ ├─5 inlineCode "backward()" (232:28-232:40, 13813-13825)
│ ├─6 text " methods, which specify the\nfunction and its derivative (or, more formally, the vector-Jacobian\nproduct). Similarly, new datasets can be added by subclassing\n" (232:40-235:1, 13825-13980)
│ ├─7 inlineCode "torch.utils.data.Dataset" (235:1-235:27, 13980-14006)
│ ├─8 text " and implementing two methods: " (235:27-235:58, 14006-14037)
│ ├─9 inlineCode "__getitem__" (235:58-235:71, 14037-14050)
│ ├─10 text "\n(the indexing operator) and " (235:71-236:29, 14050-14079)
│ ├─11 inlineCode "__len__" (236:29-236:38, 14079-14088)
│ ├─12 text " (the length operator), making\ndatasets behave like (possibly lazy) lists. How these work is completely\nup to the implementer, and many users leverage other Python packages for\ndata loading. The " (236:38-239:19, 14088-14283)
│ ├─13 inlineCode "DataLoader" (239:19-239:31, 14283-14295)
│ └─14 text " class consumes objects conforming to this\ninterface and provides an iterator over the data which takes care of\nshuffling, batching, parallelization, and management of pinned CUDA\nmemory to improve throughput." (239:31-242:30, 14295-14504)
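Below is a minimal sketch of the dataset protocol just described, written in plain Python so it runs without PyTorch. `SquaresDataset` and `iterate_batches` are hypothetical names. The loader is a toy stand-in for part of what `DataLoader` adds: it handles shuffling and batching only, with no worker processes or pinned memory. With PyTorch installed, the same class should also work as a subclass of `torch.utils.data.Dataset`.

```python
import random


class SquaresDataset:
    """Computes item i lazily on access; behaves like a read-only list."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index * index   # (input, target), computed on demand

    def __len__(self):
        return self.n


def iterate_batches(dataset, batch_size, shuffle=False, seed=0):
    """Toy loader (hypothetical): relies only on __len__ and __getitem__."""
    order = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [dataset[i] for i in order[start:start + batch_size]]


ds = SquaresDataset(5)
print(len(ds), ds[3])                        # 5 (3, 9)
print(list(iterate_batches(ds, batch_size=2)))
# [[(0, 0), (1, 1)], [(2, 4), (3, 9)], [(4, 16)]]
```

The loader never needs to know how items are produced, which is why users can back such a class with any Python data-loading package.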
├─29 paragraph[1] (244:1-247:64, 14506-14773)
│ └─0 text "Most importantly, users are free to replace any component of PyTorch\nthat does not meet the needs or performance requirements of their\nproject. They are all designed to be completely interchangeable, and\nPyTorch takes great care not to impose any particular solution." (244:1-247:64, 14506-14773)
├─30 heading[1] (249:1-249:29, 14775-14803)
│ │ depth: 2
│ └─0 text "Automatic differentiation" (249:4-249:29, 14778-14803)
├─31 paragraph[1] (251:1-265:62, 14805-15837)
│ └─0 text "Since gradient-based optimization is vital to deep learning, PyTorch\nmust be able to automatically compute gradients of models specified by\nour users, and those can be arbitrary Python programs. However, Python\nis a dynamic programming language that allows changing most behaviors at\nruntime, making ahead-of-time source-to-source differentiation\ncumbersome. Instead, PyTorch uses the operator overloading approach,\nwhich builds up a representation of the computed function every time it\nis executed. In its current implementation (Paszke et al. 2017), PyTorch\nperforms reverse-mode automatic differentiation, which computes the\ngradient of a scalar output with respect to a multivariate input.\nDifferentiating functions with more outputs than inputs is more\nefficiently executed using forward-mode automatic differentiation, but\nthis use case is less common for machine learning applications. PyTorch\ncan be easily extended to perform forward-mode differentiation using\narray-level dual numbers (Piponi 2004; Leuck and Nagel 1999)." (251:1-265:62, 14805-15837)
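To make the operator-overloading approach concrete, here is a deliberately tiny scalar reverse-mode sketch. `Var` is a hypothetical class, not PyTorch's autograd. It records each operation as the program executes, so ordinary Python control flow (here, a loop) needs no special handling. `backward()` then computes the gradient of one scalar output with respect to every input in a single reverse sweep.

```python
class Var:
    """Scalar that records the operations applied to it (operator overloading)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        self._parents = parents  # tuples of (parent Var, local derivative)

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    __rmul__ = __mul__

    def backward(self):
        """Propagate d(self)/d(node) to every node, in reverse topological order."""
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node._parents:
                    visit(parent)
                order.append(node)   # parents are appended before children

        visit(self)
        self.grad = 1.0
        for node in reversed(order):  # consumers are processed before producers
            for parent, local in node._parents:
                parent.grad += local * node.grad


def f(x, y):
    z = x * y
    for _ in range(2):   # plain Python loop, recorded as it runs
        z = z + x
    return z             # f = x*y + 2x


x, y = Var(3.0), Var(4.0)
out = f(x, y)
out.backward()
print(out.value, x.grad, y.grad)  # 18.0 6.0 3.0
```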
├─32 paragraph[1] (267:1-279:62, 15839-16750)
│ └─0 text "Another interesting and uncommon feature of our system is that it can\ndifferentiate through code employing mutation on tensors, which is one\nof the basic building blocks of imperative programs. To ensure safety,\nwe have implemented a versioning system for tensors, which lets us track\ntheir modifications and ensure that we always use the data we expect.\nOne interesting tradeoff is that while we could utilize techniques like\ncopy-on-write to support arbitrary programs, we chose to not go down\nthis path, as performance-wise it is usually beneficial for the users to\nrewrite their code to ensure that no copies have to be performed. Hence,\nwhile most mutations are benign and can be handled automatically, the\nreally complicated cases result in a user error, which lets them know\nthat they likely want to restructure the program. This allows us to\navoid introducing subtle and hard-to-find performance cliffs." (267:1-279:62, 15839-16750)
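The versioning idea can be sketched in a few lines. This is an illustration under stated assumptions, not PyTorch's implementation. `VersionedBuffer`, `SavedForBackward`, and `square_forward` are hypothetical names. Every in-place operation bumps a counter. A value saved for the backward pass remembers the counter at save time and refuses to be used if the two no longer match.

```python
class VersionedBuffer:
    """Hypothetical tensor stand-in with a version counter bumped by in-place ops."""

    def __init__(self, values):
        self.values = list(values)
        self.version = 0

    def mul_(self, k):
        """In-place multiply (trailing underscore, mirroring PyTorch's naming)."""
        self.values = [v * k for v in self.values]
        self.version += 1
        return self


class SavedForBackward:
    """Holds a buffer together with the version it had when it was saved."""

    def __init__(self, buf):
        self.buf = buf
        self.saved_version = buf.version

    def unpack(self):
        if self.buf.version != self.saved_version:
            raise RuntimeError(
                "value needed for gradient computation was modified in place "
                f"(saved at version {self.saved_version}, now {self.buf.version})"
            )
        return self.buf.values


def square_forward(buf):
    """y = x**2 needs x again in backward (dy/dx = 2x), so it saves x."""
    saved = SavedForBackward(buf)
    out = [v * v for v in buf.values]

    def backward(grad_out):
        x = saved.unpack()
        return [2 * xi * g for xi, g in zip(x, grad_out)]

    return out, backward


x = VersionedBuffer([1.0, 2.0])
y, grad_fn = square_forward(x)
first = grad_fn([1.0, 1.0])       # [2.0, 4.0]
x.mul_(10)                        # in-place update after x was saved
try:
    grad_fn([1.0, 1.0])
    detected = False
except RuntimeError as err:       # the stale value is caught, not silently used
    detected = True
    print("error:", err)
```

Reporting an error instead of silently copying matches the trade-off described above: the user learns that the program should be restructured.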
├─33 heading[1] (281:1-281:37, 16752-16788)
│ │ depth: 1
│ └─0 text "Performance focused implementation" (281:3-281:37, 16754-16788)
├─34 paragraph[1] (283:1-289:22, 16790-17227)
│ └─0 text "Running deep learning algorithms efficiently from a Python interpreter\nis notoriously challenging: for instance, the global interpreter\nlock (The Python team, n.d.) effectively ensures that only one of any\nnumber of concurrent threads is running at any given time. Deep learning\nframeworks based on the construction of a static data-flow graph\nsidestep this problem by deferring the evaluation of the computation to\na custom interpreter." (283:1-289:22, 16790-17227)
├─35 paragraph[1] (291:1-293:52, 17229-17419)
│ └─0 text "PyTorch solved the problem differently, by carefully optimizing every\naspect of its execution while simultaneously empowering its users to\neasily leverage additional optimization strategies." (291:1-293:52, 17229-17419)
├─36 heading[1] (295:1-295:25, 17421-17445)
│ │ depth: 2
│ └─0 text "An efficient C++ core" (295:4-295:25, 17424-17445)
├─37 paragraph[3] (297:1-310:18, 17447-18374)
│ ├─0 text "Despite being closely integrated in the Python ecosystem, most of\nPyTorch is written in C++ to achieve high performance. This core\n" (297:1-299:1, 17447-17578)
│ ├─1 inlineCode "libtorch" (299:1-299:11, 17578-17588)
│ └─2 text " library implements the tensor data structure, the GPU and CPU\noperators, and basic parallel primitives. It also provides the automatic\ndifferentiation system, including the gradient formulas for most\nbuilt-in functions. This ensures that the computation of the derivatives\nof functions composed of core PyTorch operators is executed entirely in\na multithreaded evaluator which does not require holding the Python\nglobal interpreter lock (The Python team, n.d.). Python bindings are\ngenerated using YAML meta-data files. An interesting side-effect of this\napproach is that it allowed our community to quickly create bindings to\nmultiple other languages resulting in projects like NimTorch (Petrantoni\nand Wollenschläger, n.d.), hasktorch (Huang, Hashimoto, and Stites,\nn.d.) and others." (299:11-310:18, 17588-18374)
├─38 paragraph[1] (312:1-317:46, 18376-18756)
│ └─0 text "This design also allowed us to create first-class C++ bindings and\nmodeling libraries that can be used in places where Python is\ninconvenient, such as the game engine for StarCraft (Synnaeve et al.\n2018) or on mobile platforms. It is even possible to take the Python\ncode describing a PyTorch model and run it without Python using the\nTorchScript engine (The PyTorch team, n.d.b)." (312:1-317:46, 18376-18756)
├─39 heading[1] (319:1-319:34, 18758-18791)
│ │ depth: 2
│ └─0 text "Separate control and data flow" (319:4-319:34, 18761-18791)
├─40 paragraph[1] (321:1-326:29, 18793-19170)
│ └─0 text "PyTorch maintains a strict separation between its control flow (i.e.,\nprogram branches and loops) and data flow (i.e., tensors and the\noperations performed on them). The resolution of the control flow is\nhandled by Python and optimized C++ code executed on the host CPU, and\nresults in a linear sequence of operator invocations on the device.\nOperators can be run either on CPU or on GPU." (321:1-326:29, 18793-19170)
├─41 paragraph[1] (328:1-337:24, 19172-19827)
│ └─0 text "PyTorch is designed to execute operators asynchronously on GPU by\nleveraging the CUDA stream mechanism (Luitjens 2014) to queue CUDA\nkernel invocations to the GPU's hardware FIFO. This allows the system to\noverlap the execution of Python code on CPU with tensor operators on\nGPU. Because the tensor operations usually take a significant amount of\ntime, this lets us saturate the GPU and reach peak performance even in\nan interpreted language with fairly high overhead like Python. Note that\nthis mechanism is nearly invisible to the user. Unless they implement\ntheir own multi-stream primitives, all of the CPU-GPU synchronization is\nhandled by the library." (328:1-337:24, 19172-19827)
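The asynchronous launch pattern can be simulated with threads to show its shape. This is a toy model, not CUDA. `FakeStream` is a hypothetical class: a FIFO queue drained in order by a worker thread that stands in for the GPU. `launch()` returns immediately, like an asynchronous kernel launch, so the "host" finishes queuing long before the "device" finishes executing. `synchronize()` is the explicit wait point.

```python
import queue
import threading
import time


class FakeStream:
    """Hypothetical model of a CUDA stream: an in-order FIFO drained by a worker."""

    def __init__(self):
        self._fifo = queue.Queue()
        self.log = []
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def _run(self):
        while True:
            item = self._fifo.get()
            if item is None:             # shutdown sentinel
                self._fifo.task_done()
                return
            name, seconds = item
            time.sleep(seconds)          # pretend the kernel takes this long
            self.log.append(name)
            self._fifo.task_done()

    def launch(self, name, seconds):
        """Enqueue work and return at once (asynchronous launch)."""
        self._fifo.put((name, seconds))

    def synchronize(self):
        """Block until all previously queued work has finished."""
        self._fifo.join()

    def close(self):
        self._fifo.put(None)
        self._worker.join()


stream = FakeStream()
t0 = time.perf_counter()
for i in range(5):
    stream.launch(f"kernel{i}", 0.02)    # the host returns right away
enqueue_time = time.perf_counter() - t0
stream.synchronize()                     # wait for the simulated device
total_time = time.perf_counter() - t0
stream.close()
print(stream.log)                        # kernels complete in launch order
print(f"enqueue {enqueue_time:.4f}s, total {total_time:.4f}s")
```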
├─42 paragraph[1] (339:1-342:25, 19829-20054)
│ └─0 text "PyTorch could leverage a similar mechanism to also execute operators\nasynchronously on the CPU. However, the costs of cross-thread\ncommunication and synchronization would negate the performance benefit\nof such an optimization." (339:1-342:25, 19829-20054)
├─43 heading[1] (344:1-344:35, 20056-20090)
│ │ depth: 2
│ └─0 text "Custom caching tensor allocator" (344:4-344:35, 20059-20090)
├─44 paragraph[3] (346:1-357:50, 20092-20903)
│ ├─0 text "Almost every operator must dynamically allocate an output tensor to hold\nthe result of its execution. It is therefore critical to optimize the\nspeed of the dynamic memory allocators. PyTorch can rely on optimized\nlibraries (Berger et al. 2000; Evans May 2006; Ghemawat and Menage,\nn.d.) to handle this task on CPU. However, on GPU the " (346:1-350:55, 20092-20427)
│ ├─1 inlineCode "cudaFree" (350:55-350:65, 20427-20437)
│ └─2 text " routine\nmay block its caller until all previously queued work on all GPUs\ncompletes. To avoid this bottleneck, PyTorch implements a custom\nallocator which incrementally builds up a cache of CUDA memory and\nreassigns it to later allocations without further use of CUDA APIs. The\nincremental allocation is also crucial for better interoperability,\nbecause taking up all GPU memory ahead of time would prevent the user\nfrom utilizing other GPU-enabled Python packages." (350:65-357:50, 20437-20903)
├─45 paragraph[1] (359:1-363:14, 20905-21204)
│ └─0 text "To further improve its effectiveness, this allocator was tuned for the\nspecific memory usage patterns of deep learning. For example, it rounds\nup allocations to multiples of 512 bytes to avoid fragmentation issues.\nMoreover, it maintains a distinct pool of memory for every CUDA stream\n(work queue)." (359:1-363:14, 20905-21204)
├─46 paragraph[3] (365:1-373:60, 21206-21825)
│ ├─0 text "The one-pool-per-stream design assumption simplifies the implementation\nand improves the performance of the allocator: because the CPU runs\nahead of the GPU, memory is freed on the CPU " (365:1-367:46, 21206-21391)
│ ├─1 emphasis[1] (367:46-367:54, 21391-21399)
│ │ └─0 text "before" (367:47-367:53, 21392-21398)
│ └─2 text " its last use on\nthe GPU finishes. Since streams serialize execution, if the free\nprecedes the reallocation on the CPU, the same order will occur on the\nGPU. So the allocator can reallocate memory freed on the CPU immediately\nas long as the new allocation is used on the same stream as the freed\nregion. However, if an allocation was last used on one stream and then\nallocated on another, additional synchronization is needed." (367:54-373:60, 21399-21825)
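The allocator behavior described in this subsection can be modeled in a few lines. This is a deliberately simplified sketch, not PyTorch's allocator. It rounds requests up to multiples of 512 bytes, caches freed blocks instead of returning them to the device, and keeps a separate pool per stream. It reuses only exact-size matches; splitting and merging of blocks are omitted. `CachingAllocator`, `round_up`, and the `device_malloc_calls` counter are hypothetical names. The counter stands in for calls to the expensive device allocator.

```python
from collections import defaultdict

ROUND = 512  # allocation granularity described in the text


def round_up(nbytes):
    """Round a request up to the next multiple of 512 bytes (minimum 512)."""
    return max(1, (nbytes + ROUND - 1) // ROUND) * ROUND


class CachingAllocator:
    """Toy model: per-stream caches of freed blocks, reused without device calls."""

    def __init__(self):
        self.free_blocks = defaultdict(list)  # (stream, size) -> [block ids]
        self.device_malloc_calls = 0
        self._next_id = 0

    def malloc(self, nbytes, stream):
        size = round_up(nbytes)
        cached = self.free_blocks[(stream, size)]
        if cached:
            return cached.pop(), size          # cache hit: no device call
        self.device_malloc_calls += 1          # cache miss: ask the device
        self._next_id += 1
        return self._next_id, size

    def free(self, block, size, stream):
        # Keep the block in this stream's pool instead of releasing it. Reuse
        # on the same stream is safe because a stream runs work in order.
        self.free_blocks[(stream, size)].append(block)


alloc = CachingAllocator()
for step in range(3):                    # e.g. three training iterations
    a = alloc.malloc(1000, stream=0)     # rounded up to 1024 bytes
    b = alloc.malloc(3000, stream=0)     # rounded up to 3072 bytes
    alloc.free(*a, stream=0)
    alloc.free(*b, stream=0)
after_loop = alloc.device_malloc_calls   # 2: only the first iteration misses
c = alloc.malloc(1000, stream=1)         # another stream has its own pool
after_other_stream = alloc.device_malloc_calls  # 3
print(after_loop, after_other_stream)
```

The first iteration pays for device allocations and later iterations are served from the cache. This mirrors the first-iteration effect reported in the evaluation.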
├─47 paragraph[1] (375:1-383:33, 21827-22418)
│ └─0 text "The one-pool-per-stream design seems limiting since the allocations end\nup fragmented per stream, but in practice PyTorch almost never uses\nmultiple streams. It is notoriously hard to write CUDA kernels in a way\nthat would let them cooperatively share the GPU because exact scheduling\nis hardware controlled. In practice, kernel writers usually resort to\nmonolithic kernels that combine multiple tasks. Data loading and\ndistributed computing utilities are exceptions to the one stream design,\nand they carefully insert additional synchronization to avoid bad\ninteractions with the allocator." (375:1-383:33, 21827-22418)
├─48 paragraph[1] (385:1-387:32, 22420-22590)
│ └─0 text "While this design is susceptible to certain corner cases, it almost\nnever exhibits unwanted behaviors in practical code. Most of our users\nare not aware of its existence." (385:1-387:32, 22420-22590)
├─49 heading[1] (389:1-389:19, 22592-22610)
│ │ depth: 2
│ └─0 text "Multiprocessing" (389:4-389:19, 22595-22610)
├─50 paragraph[3] (391:1-396:26, 22612-22985)
│ ├─0 text "Due to the global interpreter lock (GIL), Python's default implementation\ndoes not allow concurrent threads to execute in parallel. To alleviate\nthis problem, the Python community has established a standard\n" (391:1-394:1, 22612-22818)
│ ├─1 inlineCode "multiprocessing" (394:1-394:18, 22818-22835)
│ └─2 text " module, containing a number of utilities that allow\nusers to easily spawn child processes and implement basic inter-process\ncommunication primitives." (394:18-396:26, 22835-22985)
├─51 paragraph[5] (398:1-404:43, 22987-23435)
│ ├─0 text "However, the implementation of the primitives uses the same form of\nserialization used for on-disk persistence, which is inefficient when\ndealing with large arrays. Hence, PyTorch extends the Python\n" (398:1-401:1, 22987-23186)
│ ├─1 inlineCode "multiprocessing" (401:1-401:18, 23186-23203)
│ ├─2 text " module into " (401:18-401:31, 23203-23216)
│ ├─3 inlineCode "torch.multiprocessing" (401:31-401:54, 23216-23239)
│ └─4 text ", which is a\ndrop-in replacement for the built-in package and automatically moves the\ndata of tensors sent to other processes to shared memory instead of\nsending it over the communication channel." (401:54-404:43, 23239-23435)
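The mechanism can be approximated with the standard library's `multiprocessing.shared_memory` module (Python 3.8+). This is an analogy, not the implementation of `torch.multiprocessing`. Instead of pickling a large buffer through a pipe, the sender copies it once into a shared segment and transmits only the segment's name and size. The receiver uses these to attach to the same memory. The helper names are hypothetical, and for brevity both sides run in one process here.

```python
import pickle
from multiprocessing import shared_memory


def send_by_value(payload):
    """What a plain pipe or queue would transmit: the whole buffer, serialized."""
    return pickle.dumps(payload)


def send_by_handle(payload):
    """Copy the data into shared memory once; transmit only a small handle."""
    shm = shared_memory.SharedMemory(create=True, size=len(payload))
    shm.buf[:len(payload)] = payload
    message = pickle.dumps((shm.name, len(payload)))
    return shm, message


def receive_by_handle(message):
    """Receiver side: attach to the existing segment by name (no data copy)."""
    name, size = pickle.loads(message)
    return shared_memory.SharedMemory(name=name), size


payload = bytes(range(256)) * 4096              # 1 MiB of sample data
by_value = send_by_value(payload)
owner, by_handle = send_by_handle(payload)
peer, size = receive_by_handle(by_handle)       # normally in another process
same = bytes(peer.buf[:size]) == payload        # copy made only to compare
peer.close()
owner.close()
owner.unlink()                                  # release the segment
print(len(by_value), len(by_handle), same)
```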
├─52 paragraph[1] (406:1-410:45, 23437-23759)
│ └─0 text "This design greatly improves performance and makes the process isolation\nweaker, resulting in a programming model which more closely resembles\nregular threaded programs. Users can easily implement heavily parallel\nprograms that operate on independent GPUs but later synchronize\ngradients using all-reduce style primitives." (406:1-410:45, 23437-23759)
├─53 paragraph[1] (412:1-414:29, 23761-23929)
│ └─0 text "Another unique feature of this system is that it transparently handles\nsharing of CUDA tensors, making it easy to implement techniques like\nHogwild (Recht et al. 2011)." (412:1-414:29, 23761-23929)
├─54 heading[1] (416:1-416:22, 23931-23952)
│ │ depth: 2
│ └─0 text "Reference counting" (416:4-416:22, 23934-23952)
├─55 paragraph[1] (418:1-421:69, 23954-24236)
│ └─0 text "Users often design their models to utilize all memory available during\ntraining, and increasing batch sizes is a common technique for speeding\nup the process. Therefore, to deliver great performance, PyTorch has to\ntreat memory as a scarce resource that it needs to manage carefully." (418:1-421:69, 23954-24236)
├─56 paragraph[1] (423:1-433:50, 24238-24993)
│ └─0 text "Libraries with eager semantics have to manage tensor memory without\nknowing how it will be used in the future. Garbage collection is the\ntypical way to handle this automatically because it has good amortized\nperformance. In this approach, the runtime periodically investigates the\nstate of the system, enumerates used objects and frees everything else.\nHowever, by deferring the deallocation, it causes the program to use\nmore memory overall (Hertz and Berger 2005). Given the scarcity of GPU\nmemory, these overheads are unacceptable. In fact, Torch7 utilized the\ngarbage collector built into Lua, and a common anti-pattern among the\nusers was to sprinkle the program with explicit triggers to the garbage\ncollector, hoping that the out-of-memory errors would go away." (423:1-433:50, 24238-24993)
├─57 paragraph[5] (435:1-441:50, 24995-25464)
│ ├─0 text "PyTorch takes a different approach: it relies on a reference counting\nscheme to track the number of uses of each tensor, and frees the\nunderlying memory " (435:1-437:19, 24995-25148)
│ ├─1 emphasis[1] (437:19-437:32, 25148-25161)
│ │ └─0 text "immediately" (437:20-437:31, 25149-25160)
│ ├─2 text " once this count reaches zero. Note that\nPyTorch tracks both references internal to the " (437:32-438:48, 25161-25249)
│ ├─3 inlineCode "libtorch" (438:48-438:58, 25249-25259)
│ └─4 text " library and\nexternal references made by users in their Python code by integrating\nwith Python's own reference counting mechanism. This ensures that memory\nis released exactly when tensors become unneeded." (438:58-441:50, 25259-25464)
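In CPython, immediate release is easy to observe: an object is deallocated as soon as its last reference disappears, without waiting for a garbage-collection pass. The sketch below uses `weakref.finalize` to log that moment. `Storage` is a hypothetical stand-in for the memory behind a tensor. The timing shown is specific to reference-counted implementations such as CPython; under PyPy the callback may run later.

```python
import weakref


class Storage:
    """Hypothetical stand-in for the memory backing a tensor."""

    def __init__(self, nbytes):
        self.data = bytearray(nbytes)


def refcount_demo():
    events = []
    storage = Storage(1 << 20)
    alias = storage                     # a second reference to the same object
    weakref.finalize(storage, events.append, "freed")
    del storage
    after_first_del = list(events)      # []: alias still holds a reference
    del alias
    after_last_del = list(events)       # ["freed"]: the count reached zero
    return after_first_del, after_last_del


print(refcount_demo())  # ([], ['freed']) on CPython
```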
├─58 paragraph[1] (443:1-449:69, 25466-25949)
│ └─0 text "One notable caveat is that we can only guarantee the desired performance\ncharacteristics in implementations of languages that either already\nutilize reference counting (CPython, Swift, but not PyPy or many\nscripting languages such as Lua), or allow for user-defined\nbehavior for assignment, copies, and moves (e.g. C++, Rust). Bindings to\nimplementations that do not satisfy those criteria will have to\nimplement their own specialized memory management on top of PyTorch." (443:1-449:69, 25466-25949)
├─59 heading[1] (451:1-451:13, 25951-25963)
│ │ depth: 1
│ └─0 text "Evaluation" (451:3-451:13, 25953-25963)
├─60 paragraph[1] (453:1-457:25, 25965-26268)
│ └─0 text "In this section we compare the performance of PyTorch with several other\ncommonly-used deep learning libraries, and find that it achieves\ncompetitive performance across a range of tasks. All experiments were\nperformed on a workstation with two Intel Xeon E5-2698 v4 CPUs and one\nNVIDIA Quadro GP100 GPU." (453:1-457:25, 25965-26268)
├─61 heading[1] (459:1-459:25, 26270-26294)
│ │ depth: 2
│ └─0 text "Asynchronous dataflow" (459:4-459:25, 26273-26294)
├─62 paragraph[1] (461:1-464:27, 26296-26539)
│ └─0 text "We start by quantifying the ability of PyTorch to asynchronously execute\ndataflow on GPU. We use the built-in profiler (The PyTorch team, n.d.a)\nto instrument various benchmarks and record a timeline of the execution\nof a single training step." (461:1-464:27, 26296-26539)
├─63 paragraph[3] (466:1-476:56, 26541-27221)
│ ├─0 text "Figure\n" (466:1-467:1, 26541-26548)
│ ├─1 link[1] (467:1-467:48, 26548-26595)
│ │ │ title: null
│ │ │ url: "#fig:async_execution"
│ │ └─0 text "[fig:async_execution]" (467:2-467:25, 26549-26572)
│ └─2 text " shows a representative timeline of\nexecution for the first few operations of a ResNet-50 model. The host\nCPU, which queues the work, quickly outpaces the execution of the\noperators on the GPU. This allows PyTorch to achieve almost perfect\ndevice utilization. In this example, GPU execution takes around three\ntimes longer than CPU scheduling. The exact ratio depends on the\nrelative performance of the host CPU and the GPU, as well as the number\nof elements in each tensor and the average arithmetic complexity of the\nfloating-point computations to be performed on the GPU." (467:48-476:56, 26595-27221)
├─64 paragraph[3] (478:1-480:4, 27223-27297)
│ ├─0 text "::: {.center}\n" (478:1-479:1, 27223-27237)
│ ├─1 image (479:1-479:36, 27237-27272)
│ │ title: null
│ │ url: "async_kernel_launches.pdf"
│ │ alt: "image"
│ └─2 text "{width=\"\\textwidth\"}\n:::" (479:36-480:4, 27272-27297)
├─65 heading[1] (482:1-482:21, 27299-27319)
│ │ depth: 2
│ └─0 text "Memory management" (482:4-482:21, 27302-27319)
├─66 paragraph[7] (484:1-494:70, 27321-28091)
│ ├─0 text "We used the NVIDIA profiler to trace the execution of the CUDA runtime\nas well as the execution of the CUDA kernels launched during one\ntraining iteration of the ResNet-50 model. As shown in\nFigure " (484:1-487:8, 27321-27519)
│ ├─1 link[1] (487:8-487:71, 27519-27582)
│ │ │ title: null
│ │ │ url: "#fig:resnet_annotated_traces"
│ │ └─0 text "[fig:resnet_annotated_traces]" (487:9-487:40, 27520-27551)
│ ├─2 text ", the behavior of the first\niteration differs significantly from that of subsequent ones. At first,\ncalls to the CUDA memory management functions (" (487:71-490:48, 27582-27791)
│ ├─3 inlineCode "cudaMalloc" (490:48-490:60, 27791-27803)
│ ├─4 text " and\n" (490:60-491:1, 27803-27808)
│ ├─5 inlineCode "cudaFree" (491:1-491:11, 27808-27818)
│ └─6 text ") slow down the execution quite dramatically by blocking the\nCPU thread for long periods of time, hence lowering the utilization of\nthe GPU. This effect disappears in subsequent iterations as the PyTorch\ncaching memory allocator starts reusing previously allocated regions." (491:11-494:70, 27818-28091)
├─67 paragraph[3] (496:1-498:4, 28093-28171)
│ ├─0 text "::: {.center}\n" (496:1-497:1, 28093-28107)
│ ├─1 image (497:1-497:40, 28107-28146)
│ │ title: null
│ │ url: "resnet50_annotated_traces.pdf"
│ │ alt: "image"
│ └─2 text "{width=\"\\textwidth\"}\n:::" (497:40-498:4, 28146-28171)
├─68 heading[1] (500:1-500:14, 28173-28186)
│ │ depth: 2
│ └─0 text "Benchmarks" (500:4-500:14, 28176-28186)
├─69 paragraph[1] (502:1-506:66, 28188-28528)
│ └─0 text "Finally, we can get an overall sense of single-machine eager mode\nperformance of PyTorch by comparing it to three popular graph-based deep\nlearning frameworks (CNTK, MXNet and TensorFlow), a define-by-run\nframework (Chainer), and a production-oriented platform (PaddlePaddle).\nThe Appendix details all the steps needed to reproduce our setup." (502:1-506:66, 28188-28528)
├─70 paragraph[3] (508:1-514:11, 28530-28892)
│ ├─0 text "Our results are summarized in\nTable " (508:1-509:7, 28530-28566)
│ ├─1 link[1] (509:7-509:34, 28566-28593)
│ │ │ title: null
│ │ │ url: "#detailed_perf_results"
│ │ └─0 text "1" (509:8-509:9, 28567-28568)
│ └─2 text ". On all the benchmarks, the\nperformance of PyTorch is within 17% of that of the fastest\nframework. We attribute this result to the fact that these tools offload\nmost of the computation to the same version of the cuDNN and cuBLAS\nlibraries." (509:34-514:11, 28593-28892)
├─71 paragraph[3] (516:1-526:179, 28894-30712)
│ ├─0 text "::: {#detailed_perf_results}\n| | | | | | | |\n|:-------------|:-------------------------------:|:--------------------:|:--------------------:|:---------------------:|:--------------------------:|:--------------------------:|\n| Framework | " (516:1-519:18, 28894-29298)
│ ├─1 emphasis[1] (519:18-519:49, 29298-29329)
│ │ └─0 text "Throughput (higher is better)" (519:19-519:48, 29299-29328)
│ └─2 text " | | | | | |\n| | AlexNet | VGG-19 | ResNet-50 | MobileNet | GNMTv2 | NCF |\n| Chainer | $778 \\pm 15$ | N/A | $\\textbf{219} \\pm 1$ | N/A | N/A | N/A |\n| CNTK | $845 \\pm{8}$ | $84 \\pm{3}$ | $210 \\pm{1}$ | N/A | N/A | N/A |\n| MXNet | $\\textbf{1554} \\pm 22$ | $113 \\pm 1$ | $218 \\pm 2$ | $444 \\pm 2$ | N/A | N/A |\n| PaddlePaddle | $933\\pm{123}$ | $112 \\pm{2}$ | $192 \\pm{4}$ | $\\textbf{557}\\pm{24}$ | N/A | N/A |\n| TensorFlow | $1422 \\pm 27$ | $66 \\pm 2$ | $200 \\pm 1$ | $216 \\pm 15$ | $9631 \\pm 1.3%$ | $4.8e6 \\pm 2.9%$ |\n| PyTorch | $1547 \\pm 316$ | $\\textbf{119} \\pm 1$ | $212 \\pm 2$ | $463 \\pm 17$ | $\\textbf{15512} \\pm 4.8%$ | $\\textbf{5.4e6} \\pm 3.4%$ |" (519:49-526:179, 29329-30712)
├─72 paragraph[1] (528:1-533:4, 30714-31006)
│ └─0 text "Training speed for 6 models using 32-bit floats. Throughput is measured\nin images per second for the AlexNet, VGG-19, ResNet-50, and MobileNet\nmodels, in tokens per second for the GNMTv2 model, and in samples per\nsecond for the NCF model. The fastest speed for each model is shown in\nbold.\n:::" (528:1-533:4, 30714-31006)
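The snippet below is a quick arithmetic check on the "within 17%" claim. It takes the mean throughputs from the table and computes how far PyTorch trails the fastest framework on each model. It ignores the reported variances and omits N/A entries, so it is only a rough check.

```python
# Mean throughputs copied from the table (N/A entries omitted).
results = {
    "AlexNet":   {"Chainer": 778, "CNTK": 845, "MXNet": 1554,
                  "PaddlePaddle": 933, "TensorFlow": 1422, "PyTorch": 1547},
    "VGG-19":    {"CNTK": 84, "MXNet": 113, "PaddlePaddle": 112,
                  "TensorFlow": 66, "PyTorch": 119},
    "ResNet-50": {"Chainer": 219, "CNTK": 210, "MXNet": 218,
                  "PaddlePaddle": 192, "TensorFlow": 200, "PyTorch": 212},
    "MobileNet": {"MXNet": 444, "PaddlePaddle": 557, "TensorFlow": 216,
                  "PyTorch": 463},
    "GNMTv2":    {"TensorFlow": 9631, "PyTorch": 15512},
    "NCF":       {"TensorFlow": 4.8e6, "PyTorch": 5.4e6},
}


def gap_to_fastest(per_framework, name="PyTorch"):
    """Fraction by which `name` trails the fastest entry (0.0 if it is fastest)."""
    best = max(per_framework.values())
    return 1.0 - per_framework[name] / best


gaps = {model: gap_to_fastest(r) for model, r in results.items()}
for model, gap in gaps.items():
    print(f"{model:10s} {gap:6.1%}")
worst = max(gaps, key=gaps.get)
print(worst, round(gaps[worst], 3))  # MobileNet 0.169
```

The largest gap is on MobileNet (463 vs. 557 images per second, about 16.9%). The other models fall between 0% and about 3.2%.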
├─73 heading[1] (535:1-535:12, 31008-31019)
│ │ depth: 2
│ └─0 text "Adoption" (535:4-535:12, 31011-31019)
├─74 paragraph[3] (537:1-548:34, 31021-31811)
│ ├─0 text "The validity of design decisions and their impact on ease-of-use is hard\nto measure. As a proxy, we tried to quantify how well the machine\nlearning community received PyTorch by counting how often various\nmachine learning tools (including Caffe, Chainer, CNTK, Keras, MXNet,\nPyTorch, TensorFlow, and Theano) are mentioned on arXiv e-Prints since\nthe initial release of PyTorch in January 2017. In Figure\n" (537:1-543:1, 31021-31425)
│ ├─1 link[1] (543:1-543:54, 31425-31478)
│ │ │ title: null
│ │ │ url: "#fig:pytorch_references"
│ │ └─0 text "[fig:pytorch_references]" (543:2-543:28, 31426-31452)
│ └─2 text "{reference-type=\"ref\"\nreference=\"fig:pytorch_references\"} we report the monthly number of\nmentions of the word \"PyTorch\" as a percentage of all mentions among\nthese deep learning frameworks. We counted tools mentioned multiple\ntimes in a given paper only once, and made the search case insensitive\nto account for various spellings." (543:54-548:34, 31478-31811)
├─75 paragraph[3] (550:1-552:4, 31813-31880)
│ ├─0 text "::: {.center}\n" (550:1-551:1, 31813-31827)
│ ├─1 image (551:1-551:29, 31827-31855)
│ │ title: null
│ │ url: "arxiv_mentions.pdf"
│ │ alt: "image"
│ └─2 text "{width=\"\\linewidth\"}\n:::" (551:29-552:4, 31855-31880)
├─76 heading[1] (554:1-554:29, 31882-31910)
│ │ depth: 1
│ └─0 text "Conclusion and future work" (554:3-554:29, 31884-31910)
├─77 paragraph[1] (556:1-566:17, 31912-32615)
│ └─0 text "PyTorch has become a popular tool in the deep learning research\ncommunity by combining a focus on usability with careful performance\nconsiderations. In addition to continuing to support the latest trends\nand advances in deep learning, in the future we plan to continue to\nimprove the speed and scalability of PyTorch. Most notably, we are\nworking on the PyTorch JIT: a suite of tools that allow PyTorch programs\nto be executed outside of the Python interpreter where they can be\nfurther optimized. We also intend to improve support for distributed\ncomputation by providing efficient primitives for data parallelism as\nwell as a Pythonic library for model parallelism based around remote\nprocedure calls." (556:1-566:17, 31912-32615)
├─78 heading[1] (568:1-568:19, 32617-32635)
│ │ depth: 1
│ └─0 text "Acknowledgements" (568:3-568:19, 32619-32635)
├─79 paragraph[1] (570:1-587:57, 32637-33864)
│ └─0 text "We are grateful to the PyTorch community for their feedback and\ncontributions that greatly influenced the design and implementation of\nPyTorch. We thank all the PyTorch core team members, contributors and\npackage maintainers including Ailing Zhang, Alex Suhan, Alfredo Mendoza,\nAlican Bozkurt, Andrew Tulloch, Ansha Yu, Anthony Shoumikhin, Bram\nWasti, Brian Vaughan, Christian Puhrsch, David Reiss, David Riazati,\nDavide Libenzi, Dmytro Dzhulgakov, Dwaraj Rajagopal, Edward Yang, Elias\nEllison, Fritz Obermeyer, George Zhang, Hao Lu, Hong Xu, Hung Duong,\nIgor Fedan, Ilia Cherniavskii, Iurii Zdebskyi, Ivan Kobzarev, James\nReed, Jeff Smith, Jerry Chen, Jerry Zhang, Jiakai Liu, Johannes M.\nDieterich, Karl Ostmo, Lin Qiao, Martin Yuan, Michael Suo, Mike Ruberry,\nMikhail Zolothukhin, Mingzhe Li, Neeraj Pradhan, Nick Korovaiko, Owen\nAnderson, Pavel Belevich, Peter Johnson, Pritam Damania, Raghuraman\nKrishnamoorthi, Richard Zou, Roy Li, Rui Zhu, Sebastian Messmer, Shen\nLi, Simon Wang, Supriya Rao, Tao Xu, Thomas Viehmann, Vincent\nQuenneville-Belair, Vishwak Srinivasan, Vitaly Fedyunin, Wanchao Liang,\nWei Yang, Will Feng, Xiaomeng Yang, Xiaoqiang Zheng, Xintao Chen,\nYangqing Jia, Yanli Zhao, Yinghai Lu and Zafar Takhirov." (570:1-587:57, 32637-33864)
├─80 paragraph[3] (589:1-595:4, 33866-34165)
│ ├─0 text "::: {#refs .references .csl-bib-body .hanging-indent}\n::: {#ref-TF .csl-entry}\nAbadi, Martı́n, Ashish Agarwal, Paul Barham, Eugene Brevdo, Zhifeng Chen,\nCraig Citro, Greg S. Corrado, et al. 2015. \"TensorFlow: Large-Scale\nMachine Learning on Heterogeneous Systems.\"\n" (589:1-594:1, 33866-34131)
│ ├─1 link[1] (594:1-594:30, 34131-34160)
│ │ │ title: null
│ │ │ url: "https://www.tensorflow.org/"
│ │ └─0 text "https://www.tensorflow.org/" (594:2-594:29, 34132-34159)
│ └─2 text ".\n:::" (594:30-595:4, 34160-34165)
├─81 paragraph[1] (597:1-600:4, 34167-34271)
│ └─0 text "::: {#ref-APL .csl-entry}\nAbrams, Philip S. 1970. \"An APL Machine.\" PhD thesis, Stanford\nUniversity.\n:::" (597:1-600:4, 34167-34271)
├─82 paragraph[5] (602:1-605:4, 34273-34402)
│ ├─0 text "::: {#ref-jax .csl-entry}\nJohnson, Matthew, et al. 2018. \"Jax.\" " (602:1-603:39, 34273-34337)
│ ├─1 emphasis[1] (603:39-603:58, 34337-34356)
│ │ └─0 text "GitHub Repository" (603:40-603:57, 34338-34355)
│ ├─2 text ".\n" (603:58-604:1, 34356-34358)
│ ├─3 link[1] (604:1-604:32, 34358-34389)
│ │ │ title: null
│ │ │ url: "https://github.com/google/jax"
│ │ └─0 text "https://github.com/google/jax" (604:2-604:31, 34359-34388)
│ └─4 text "; GitHub.\n:::" (604:32-605:4, 34389-34402)
├─83 paragraph[5] (607:1-610:4, 34404-34537)
│ ├─0 text "::: {#ref-flux .csl-entry}\nInnes, Mike, et al. 2018. \"Flux.jl.\" " (607:1-608:38, 34404-34468)
│ ├─1 emphasis[1] (608:38-608:57, 34468-34487)
│ │ └─0 text "GitHub Repository" (608:39-608:56, 34469-34486)
│ ├─2 text ".\n" (608:57-609:1, 34487-34489)
│ ├─3 link[1] (609:1-609:36, 34489-34524)
│ │ │ title: null
│ │ │ url: "https://github.com/FluxML/Flux.jl"
│ │ └─0 text "https://github.com/FluxML/Flux.jl" (609:2-609:35, 34490-34523)
│ └─4 text "; GitHub.\n:::" (609:36-610:4, 34524-34537)
├─84 paragraph[5] (612:1-617:4, 34539-34837)
│ ├─0 text "::: {#ref-autodiff_survey .csl-entry}\nBaydin, Atilim Gunes, Barak A. Pearlmutter, Alexey Andreyevich Radul,\nand Jeffrey Mark Siskind. 2017. \"Automatic Differentiation in Machine\nLearning: A Survey.\" " (612:1-615:22, 34539-34738)
│ ├─1 emphasis[1] (615:22-615:44, 34738-34760)
│ │ └─0 text "J. Mach. Learn. Res." (615:23-615:43, 34739-34759)
│ ├─2 text " 18 (1): 5595--5637.\n" (615:44-616:1, 34760-34781)
│ ├─3 link[1] (616:1-616:52, 34781-34832)
│ │ │ title: null
│ │ │ url: "http://dl.acm.org/citation.cfm?id=3122009.3242010"
│ │ └─0 text "http://dl.acm.org/citation.cfm?id=3122009.3242010" (616:2-616:51, 34782-34831)
│ └─4 text ".\n:::" (616:52-617:4, 34832-34837)
├─85 paragraph[5] (619:1-626:4, 34839-35237)
│ ├─0 text "::: {#ref-hoard .csl-entry}\nBerger, Emery D., Kathryn S. McKinley, Robert D. Blumofe, and Paul R.\nWilson. 2000. \"Hoard: A Scalable Memory Allocator for Multithreaded\nApplications.\" In " (619:1-622:19, 34839-35023)
│ ├─1 emphasis[1] (622:19-623:71, 35023-35147)
│ │ └─0 text "Proceedings of the Ninth International Conference on\nArchitectural Support for Programming Languages and Operating Systems" (622:20-623:70, 35024-35146)
│ ├─2 text ",\n117--28. ASPLOS IX. New York, NY, USA: ACM.\n" (623:71-625:1, 35147-35193)
│ ├─3 link[1] (625:1-625:40, 35193-35232)
│ │ │ title: null
│ │ │ url: "https://doi.org/10.1145/378993.379232"
│ │ └─0 text "https://doi.org/10.1145/378993.379232" (625:2-625:39, 35194-35231)
│ └─4 text ".\n:::" (625:40-626:4, 35232-35237)
├─86 paragraph[5] (628:1-632:4, 35239-35459)
│ ├─0 text "::: {#ref-Julia .csl-entry}\nBezanson, Jeff, Alan Edelman, Stefan Karpinski, and Viral B Shah. 2017.\n\"Julia: A Fresh Approach to Numerical Computing.\" " (628:1-630:51, 35239-35389)
│ ├─1 emphasis[1] (630:51-630:64, 35389-35402)
│ │ └─0 text "SIAM Review" (630:52-630:63, 35390-35401)
│ ├─2 text " 59 (1):\n65--98. " (630:64-631:9, 35402-35419)
│ ├─3 link[1] (631:9-631:44, 35419-35454)
│ │ │ title: null
│ │ │ url: "https://doi.org/10.1137/141000671"
│ │ └─0 text "https://doi.org/10.1137/141000671" (631:10-631:43, 35420-35453)
│ └─4 text ".\n:::" (631:44-632:4, 35454-35459)
├─87 paragraph[3] (634:1-638:4, 35461-35691)
│ ├─0 text "::: {#ref-cudnn .csl-entry}\nChetlur, Sharan, Cliff Woolley, Philippe Vandermersch, Jonathan D.\nCohen, John Tran, Bryan Catanzaro, and Evan Shelhamer. 2014. \"cuDNN:\nEfficient Primitives for Deep Learning.\" " (634:1-637:42, 35461-35666)
│ ├─1 emphasis[1] (637:42-637:48, 35666-35672)
│ │ └─0 text "CoRR" (637:43-637:47, 35667-35671)
│ └─2 text " abs/1410.0759.\n:::" (637:48-638:4, 35672-35691)
├─88 paragraph[1] (640:1-643:4, 35693-35844)
│ └─0 text "::: {#ref-Torch .csl-entry}\nCollobert, Ronan, Samy Bengio, and Johnny Mariéthoz. 2002. \"Torch: A\nModular Machine Learning Software Library.\" Idiap.\n:::" (640:1-643:4, 35693-35844)
├─89 paragraph[3] (645:1-648:4, 35846-36016)
│ ├─0 text "::: {#ref-Torch7 .csl-entry}\nCollobert, Ronan, Koray Kavukcuoglu, and Clément Farabet. 2011. \"Torch7:\nA Matlab-Like Environment for Machine Learning.\" In " (645:1-647:53, 35846-36000)
│ ├─1 emphasis[1] (647:53-647:64, 36000-36011)
│ │ └─0 text "NIPS 2011" (647:54-647:63, 36001-36010)
│ └─2 text ".\n:::" (647:64-648:4, 36011-36016)
├─90 paragraph[1] (650:1-652:4, 36018-36104)
│ └─0 text "::: {#ref-dlpack .csl-entry}\nDMLC. n.d. \"DLPack: Open in Memory Tensor Structure.\"\n:::" (650:1-652:4, 36018-36104)
├─91 paragraph[5] (654:1-659:4, 36106-36357)
│ ├─0 text "::: {#ref-jemalloc .csl-entry}\nEvans, J. May 2006. \"A Scalable Concurrent Malloc(3) Implementation for\nFreeBSD.\" In " (654:1-656:14, 36106-36222)
│ ├─1 emphasis[1] (656:14-656:58, 36222-36266)
│ │ └─0 text "BSDCan --- the Technical BSD Conference" (656:15-656:57, 36223-36265)
│ ├─2 text ". Ottawa,\nCanada.\n" (656:58-658:1, 36266-36284)
│ ├─3 link[1] (658:1-658:69, 36284-36352)
│ │ │ title: null
│ │ │ url: "http://people.freebsd.org/~jasone/jemalloc/bsdcan2006/jemalloc.pdf"
│ │ └─0 text "http://people.freebsd.org/~jasone/jemalloc/bsdcan2006/jemalloc.pdf" (658:2-658:68, 36285-36351)
│ └─4 text ".\n:::" (658:69-659:4, 36352-36357)
├─92 paragraph[1] (661:1-663:4, 36359-36454)
│ └─0 text "::: {#ref-worse_is_better .csl-entry}\nGabriel, Richard. n.d. \"The Rise of Worse Is Better.\"\n:::" (661:1-663:4, 36359-36454)
├─93 paragraph[3] (665:1-668:4, 36456-36618)
│ ├─0 text "::: {#ref-tcmalloc .csl-entry}\nGhemawat, S., and P. Menage. n.d. \"Tcmalloc: Thread-Caching Malloc.\"\n" (665:1-667:1, 36456-36556)
│ ├─1 link[1] (667:1-667:58, 36556-36613)
│ │ │ title: null
│ │ │ url: "http://goog-perftools.sourceforge.net/doc/tcmalloc.html"
│ │ └─0 text "http://goog-perftools.sourceforge.net/doc/tcmalloc.html" (667:2-667:57, 36557-36612)
│ └─2 text ".\n:::" (667:58-668:4, 36613-36618)
├─94 paragraph[1] (670:1-673:4, 36620-36739)
│ └─0 text "::: {#ref-eigenweb .csl-entry}\nGuennebaud, Gaël, Benoît Jacob, et al. 2010. \"Eigen V3.\"\nhttp://eigen.tuxfamily.org.\n:::" (670:1-673:4, 36620-36739)
├─95 paragraph[5] (675:1-681:4, 36741-37129)
│ ├─0 text "::: {#ref-garbage_collection .csl-entry}\nHertz, Matthew, and Emery D. Berger. 2005. \"Quantifying the Performance\nof Garbage Collection Vs. Explicit Memory Management.\" In " (675:1-677:59, 36741-36912)
│ ├─1 emphasis[1] (677:59-679:51, 36912-37036)
│ │ └─0 text "Proceedings\nof the 20th Annual ACM SIGPLAN Conference on Object-Oriented\nProgramming, Systems, Languages, and Applications" (677:60-679:50, 36913-37035)
│ ├─2 text ", 313--26. OOPSLA '05.\nNew York, NY, USA: ACM. " (679:51-680:25, 37036-37083)
│ ├─3 link[1] (680:25-680:66, 37083-37124)
│ │ │ title: null
│ │ │ url: "https://doi.org/10.1145/1094811.1094836"
│ │ └─0 text "https://doi.org/10.1145/1094811.1094836" (680:26-680:65, 37084-37123)
│ └─4 text ".\n:::" (680:66-681:4, 37124-37129)
├─96 paragraph[1] (683:1-685:4, 37131-37232)
│ └─0 text "::: {#ref-hasktorch .csl-entry}\nHuang, Austin, Junji Hashimoto, and Sam Stites. n.d. \"HaskTorch.\"\n:::" (683:1-685:4, 37131-37232)
├─97 paragraph[3] (687:1-692:4, 37234-37525)
│ ├─0 text "::: {#ref-Caffe .csl-entry}\nJia, Yangqing, Evan Shelhamer, Jeff Donahue, Sergey Karayev, Jonathan\nLong, Ross Girshick, Sergio Guadarrama, and Trevor Darrell. 2014.\n\"Caffe: Convolutional Architecture for Fast Feature Embedding.\"\n" (687:1-691:1, 37234-37474)
│ ├─1 emphasis[1] (691:1-691:37, 37474-37510)
│ │ └─0 text "arXiv Preprint arXiv:1408.5093" (691:2-691:36, 37475-37509)
│ └─2 text ".\n:::" (691:37-692:4, 37510-37525)
├─98 paragraph[1] (694:1-697:4, 37527-37670)
│ └─0 text "::: {#ref-SciPy .csl-entry}\nJones, Eric, Travis Oliphant, Pearu Peterson, et al. 2001--. \"SciPy:\nOpen Source Scientific Tools for Python.\"\n:::" (694:1-697:4, 37527-37670)
├─99 paragraph[1] (699:1-702:4, 37672-37804)
│ └─0 text "::: {#ref-maxdnn .csl-entry}\nLavin, Andrew. 2015. \"maxDNN: An Efficient Convolution Kernel for Deep\nLearning with Maxwell GPUs.\"\n:::" (699:1-702:4, 37672-37804)
├─100 paragraph[3] (704:1-708:4, 37806-38014)
│ ├─0 text "::: {#ref-fast_cnn .csl-entry}\nLavin, Andrew, and Scott Gray. 2016. \"Fast Algorithms for Convolutional\nNeural Networks.\" " (704:1-706:19, 37806-37927)
│ ├─1 emphasis[1] (706:19-707:20, 37927-37999)
│ │ └─0 text "2016 IEEE Conference on Computer Vision and Pattern\nRecognition (CVPR)" (706:20-707:19, 37928-37998)
│ └─2 text ", 4013--21.\n:::" (707:20-708:4, 37999-38014)
├─101 paragraph[3] (710:1-714:4, 38016-38193)
│ ├─0 text "::: {#ref-mnist .csl-entry}\nLeCun, Yann, and Corinna Cortes. n.d. \"MNIST Handwritten Digit\nDatabase.\"\n" (710:1-713:1, 38016-38153)
│ ├─1 link[1] (713:1-713:36, 38153-38188)
│ │ │ title: null
│ │ │ url: "http://yann.lecun.com/exdb/mnist/"
│ │ └─0 text "http://yann.lecun.com/exdb/mnist/" (713:2-713:35, 38154-38187)
│ └─2 text ".\n:::" (713:36-714:4, 38188-38193)
├─102 paragraph[1] (716:1-719:4, 38195-38327)
│ └─0 text "::: {#ref-Lush .csl-entry}\nLeCun, Y, and L Bottou. 2002. \"Lush Reference Manual.\" code available at\nhttp://lush.sourceforge.net.\n:::" (716:1-719:4, 38195-38327)
├─103 paragraph[5] (721:1-727:4, 38329-38691)
│ ├─0 text "::: {#ref-Leuck-dual-numbers .csl-entry}\nLeuck, Holger, and Hans-Hellmut Nagel. 1999. \"Automatic Differentiation\nFacilitates OF-Integration into Steering-Angle-Based Road Vehicle\nTracking.\" In " (721:1-724:15, 38329-38522)
│ ├─1 emphasis[1] (724:15-725:63, 38522-38632)
│ │ └─0 text "1999 Conference on Computer Vision and Pattern\nRecognition (CVPR '99), 23-25 June 1999, Ft. Collins, CO, USA" (724:16-725:62, 38523-38631)
│ ├─2 text ",\n2360--65. " (725:63-726:11, 38632-38644)
│ ├─3 link[1] (726:11-726:53, 38644-38686)
│ │ │ title: null
│ │ │ url: "https://doi.org/10.1109/CVPR.1999.784659"
│ │ └─0 text "https://doi.org/10.1109/CVPR.1999.784659" (726:12-726:52, 38645-38685)
│ └─4 text ".\n:::" (726:53-727:4, 38686-38691)
├─104 paragraph[3] (729:1-732:4, 38693-38883)
│ ├─0 text "::: {#ref-cuda_stream .csl-entry}\nLuitjens, Justin. 2014. \"CUDA Streams.\"\n" (729:1-731:1, 38693-38767)
│ ├─1 link[1] (731:1-731:112, 38767-38878)
│ │ │ title: null
│ │ │ url: "http://on-demand.gputechconf.com/gtc/2014/presentations/S4158-cuda-streams-best-practices-common-pitfalls.pdf"
│ │ └─0 text "http://on-demand.gputechconf.com/gtc/2014/presentations/S4158-cuda-streams-best-practices-common-pitfalls.pdf" (731:2-731:111, 38768-38877)
│ └─2 text ".\n:::" (731:112-732:4, 38878-38883)
├─105 paragraph[1] (734:1-737:4, 38885-39066)
│ └─0 text "::: {#ref-maclaurin2016phd .csl-entry}\nMaclaurin, Dougal. 2016. \"Modeling, Inference and Optimization with\nComposable Differentiable Procedures.\" PhD thesis, Harvard University.\n:::" (734:1-737:4, 38885-39066)
├─106 paragraph[3] (739:1-742:4, 39068-39196)
│ ├─0 text "::: {#ref-Matlab .csl-entry}\n" (739:1-740:1, 39068-39097)
│ ├─1 emphasis[1] (740:1-740:32, 39097-39128)
│ │ └─0 text "MATLAB and Statistics Toolbox" (740:2-740:31, 39098-39127)
│ └─2 text ". n.d. Natick, Massachusetts, United\nStates: The MathWorks, Inc.\n:::" (740:32-742:4, 39128-39196)
├─107 paragraph[3] (744:1-748:4, 39198-39371)
│ ├─0 text "::: {#ref-Pandas .csl-entry}\nMcKinney, Wes. 2010. \"Data Structures for Statistical Computing in\nPython.\" In " (744:1-746:13, 39198-39306)
│ ├─1 emphasis[1] (746:13-747:7, 39306-39366)
│ │ └─0 text "Proceedings of the 9th Python in Science Conference,\n51-56" (746:14-747:6, 39307-39365)
│ └─2 text ".\n:::" (747:7-748:4, 39366-39371)
├─108 paragraph[5] (750:1-755:4, 39373-39617)
│ ├─0 text "::: {#ref-DyNet .csl-entry}\nNeubig, G., C. Dyer, Y. Goldberg, A. Matthews, W. Ammar, A.\nAnastasopoulos, M. Ballesteros, et al. 2017. \"DyNet: The Dynamic Neural\nNetwork Toolkit.\" " (750:1-753:19, 39373-39551)
│ ├─1 emphasis[1] (753:19-753:35, 39551-39567)
│ │ └─0 text "ArXiv e-Prints" (753:20-753:34, 39552-39566)
│ ├─2 text ", January.\n" (753:35-754:1, 39567-39578)
│ ├─3 link[1] (754:1-754:35, 39578-39612)
│ │ │ title: null
│ │ │ url: "https://arxiv.org/abs/1701.03980"
│ │ └─0 text "https://arxiv.org/abs/1701.03980" (754:2-754:34, 39579-39611)
│ └─4 text ".\n:::" (754:35-755:4, 39612-39617)
├─109 paragraph[1] (757:1-760:4, 39619-39726)
│ └─0 text "::: {#ref-Numpy .csl-entry}\nOliphant, Travis. 2006. \"NumPy: A Guide to NumPy.\" USA: Trelgol\nPublishing.\n:::" (757:1-760:4, 39619-39726)
├─110 paragraph[3] (762:1-766:4, 39728-39982)
│ ├─0 text "::: {#ref-pytorch_autodiff .csl-entry}\nPaszke, Adam, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang,\nZachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam\nLerer. 2017. \"Automatic Differentiation in PyTorch.\" In " (762:1-765:57, 39728-39962)
│ ├─1 emphasis[1] (765:57-765:72, 39962-39977)
│ │ └─0 text "NIPS Workshop" (765:58-765:71, 39963-39976)
│ └─2 text ".\n:::" (765:72-766:4, 39977-39982)
├─111 paragraph[1] (768:1-770:4, 39984-40082)
│ └─0 text "::: {#ref-nimtorch .csl-entry}\nPetrantoni, Giovanni, and Jörg Wollenschläger. n.d. \"NimTorch.\"\n:::" (768:1-770:4, 39984-40082)
├─112 paragraph[5] (772:1-776:4, 40084-40310)
│ ├─0 text "::: {#ref-Piponi-dual-numbers .csl-entry}\nPiponi, Dan. 2004. \"Automatic Differentiation, C++ Templates, and\nPhotogrammetry.\" " (772:1-774:18, 40084-40209)
│ ├─1 emphasis[1] (774:18-774:50, 40209-40241)
│ │ └─0 text "J. Graphics, GPU, & Game Tools" (774:19-774:49, 40210-40240)
│ ├─2 text " 9 (4): 41--55.\n" (774:50-775:1, 40241-40257)
│ ├─3 link[1] (775:1-775:49, 40257-40305)
│ │ │ title: null
│ │ │ url: "https://doi.org/10.1080/10867651.2004.10504901"
│ │ └─0 text "https://doi.org/10.1080/10867651.2004.10504901" (775:2-775:48, 40258-40304)
│ └─4 text ".\n:::" (775:49-776:4, 40305-40310)
├─113 paragraph[5] (778:1-782:4, 40312-40502)
│ ├─0 text "::: {#ref-R .csl-entry}\nR Core Team. n.d. " (778:1-779:19, 40312-40354)
│ ├─1 emphasis[1] (779:19-780:11, 40354-40411)
│ │ └─0 text "R: A Language and Environment for Statistical\nComputing" (779:20-780:10, 40355-40410)
│ ├─2 text ". Vienna, Austria: R Foundation for Statistical Computing.\n" (780:11-781:1, 40411-40470)
│ ├─3 link[1] (781:1-781:28, 40470-40497)
│ │ │ title: null
│ │ │ url: "http://www.R-project.org/"
│ │ └─0 text "http://www.R-project.org/" (781:2-781:27, 40471-40496)
│ └─4 text ".\n:::" (781:28-782:4, 40497-40502)
├─114 paragraph[5] (784:1-792:4, 40504-41004)
│ ├─0 text "::: {#ref-Hogwild .csl-entry}\nRecht, Benjamin, Christopher Ré, Stephen J. Wright, and Feng Niu. 2011.\n\"Hogwild: A Lock-Free Approach to Parallelizing Stochastic Gradient\nDescent.\" In " (784:1-787:14, 40504-40687)
│ ├─1 emphasis[1] (787:14-789:68, 40687-40879)
│ │ └─0 text "Advances in Neural Information Processing Systems 24: 25th\nAnnual Conference on Neural Information Processing Systems 2011.\nProceedings of a Meeting Held 12-14 December 2011, Granada, Spain." (787:15-789:67, 40688-40878)
│ ├─2 text ",\n693--701.\n" (789:68-791:1, 40879-40891)
│ ├─3 link[1] (791:1-791:109, 40891-40999)
│ │ │ title: null
│ │ │ url: "http://papers.nips.cc/paper/4390-hogwild-a-lock-free-approach-to-parallelizing-stochastic-gradient-descent"
│ │ └─0 text "http://papers.nips.cc/paper/4390-hogwild-a-lock-free-approach-to-parallelizing-stochastic-gradient-descent" (791:2-791:108, 40892-40998)
│ └─4 text ".\n:::" (791:109-792:4, 40999-41004)
├─115 paragraph[5] (794:1-800:4, 41006-41320)
│ ├─0 text "::: {#ref-CNTK .csl-entry}\nSeide, Frank, and Amit Agarwal. 2016. \"CNTK: Microsoft's Open-Source\nDeep-Learning Toolkit.\" In " (794:1-796:28, 41006-41129)
│ ├─1 emphasis[1] (796:28-797:65, 41129-41229)
│ │ └─0 text "Proceedings of the 22nd ACM SIGKDD\nInternational Conference on Knowledge Discovery and Data Mining" (796:29-797:64, 41130-41228)
│ ├─2 text ",\n2135--35. KDD '16. New York, NY, USA: ACM.\n" (797:65-799:1, 41229-41274)
│ ├─3 link[1] (799:1-799:42, 41274-41315)
│ │ │ title: null
│ │ │ url: "https://doi.org/10.1145/2939672.2945397"
│ │ └─0 text "https://doi.org/10.1145/2939672.2945397" (799:2-799:41, 41275-41314)
│ └─4 text ".\n:::" (799:42-800:4, 41315-41320)
├─116 paragraph[3] (802:1-807:4, 41322-41566)
│ ├─0 text "::: {#ref-EBLearn .csl-entry}\nSermanet, Pierre, Koray Kavukcuoglu, and Yann LeCun. 2009. \"Eblearn:\nOpen-Source Energy-Based Learning in c++.\" In " (802:1-804:47, 41322-41467)
│ ├─1 emphasis[1] (804:47-805:64, 41467-41546)
│ │ └─0 text "2009 21st IEEE\nInternational Conference on Tools with Artificial Intelligence" (804:48-805:63, 41468-41545)
│ └─2 text ",\n693--97. IEEE.\n:::" (805:64-807:4, 41546-41566)
├─117 paragraph[3] (809:1-814:4, 41568-41859)
│ ├─0 text "::: {#ref-starcraft_pytorch .csl-entry}\nSynnaeve, G., Z. Lin, J. Gehring, D. Gant, V. Mella, V. Khalidov, N.\nCarion, and N. Usunier. 2018. \"Forward Modeling for Partial Observation\nStrategy Games - a Starcraft Defogger.\" In " (809:1-812:44, 41568-41792)
│ ├─1 emphasis[1] (812:44-813:32, 41792-41843)
│ │ └─0 text "Advances in Neural\nInformation Processing Systems" (812:45-813:31, 41793-41842)
│ └─2 text ", 10761--71.\n:::" (813:32-814:4, 41843-41859)
├─118 paragraph[1] (816:1-818:4, 41861-41959)
│ └─0 text "::: {#ref-python_gil .csl-entry}\nThe Python Team. n.d. \"The CPython Global Interpreter Lock.\"\n:::" (816:1-818:4, 41861-41959)
├─119 paragraph[3] (820:1-822:4, 41961-42059)
│ ├─0 text "::: {#ref-autograd_profiler .csl-entry}\nThe PyTorch Team. n.d.a. " (820:1-821:27, 41961-42027)
│ ├─1 emphasis[1] (821:27-821:54, 42027-42054)
│ │ └─0 text "Pytorch Autograd Profiler" (821:28-821:53, 42028-42053)
│ └─2 text ".\n:::" (821:54-822:4, 42054-42059)
├─120 paragraph[3] (824:1-826:4, 42061-42132)
│ ├─0 text "::: {#ref-torchscript .csl-entry}\n---------. n.d.b. " (824:1-825:19, 42061-42113)
│ ├─1 emphasis[1] (825:19-825:33, 42113-42127)
│ │ └─0 text "Torch Script" (825:20-825:32, 42114-42126)
│ └─2 text ".\n:::" (825:33-826:4, 42127-42132)
├─121 paragraph[5] (828:1-832:4, 42134-42361)
│ ├─0 text "::: {#ref-Theano .csl-entry}\nTheano Development Team. 2016. \"[Theano: A Python framework for fast\ncomputation of mathematical expressions]{.nocase}.\" " (828:1-830:53, 42134-42284)
│ ├─1 emphasis[1] (830:53-830:69, 42284-42300)
│ │ └─0 text "arXiv e-Prints" (830:54-830:68, 42285-42299)
│ ├─2 text "\nabs/1605.02688 (May). " (830:69-831:23, 42300-42323)
│ ├─3 link[1] (831:23-831:56, 42323-42356)
│ │ │ title: null
│ │ │ url: "http://arxiv.org/abs/1605.02688"
│ │ └─0 text "http://arxiv.org/abs/1605.02688" (831:24-831:55, 42324-42355)
│ └─4 text ".\n:::" (831:56-832:4, 42356-42361)
├─122 paragraph[5] (834:1-841:4, 42363-42752)
│ ├─0 text "::: {#ref-Chainer .csl-entry}\nTokui, Seiya, Kenta Oono, Shohei Hido, and Justin Clayton. 2015.\n\"Chainer: A Next-Generation Open Source Framework for Deep Learning.\" In\n" (834:1-837:1, 42363-42531)
│ ├─1 emphasis[1] (837:1-839:16, 42531-42684)
│ │ └─0 text "Proceedings of Workshop on Machine Learning Systems (LearningSys) in\nthe Twenty-Ninth Annual Conference on Neural Information Processing\nSystems (NIPS)" (837:2-839:15, 42532-42683)
│ ├─2 text ".\n" (839:16-840:1, 42684-42686)
│ ├─3 link[1] (840:1-840:62, 42686-42747)
│ │ │ title: null
│ │ │ url: "http://learningsys.org/papers/LearningSys_2015_paper_33.pdf"
│ │ └─0 text "http://learningsys.org/papers/LearningSys_2015_paper_33.pdf" (840:2-840:61, 42687-42746)
│ └─4 text ".\n:::" (840:62-841:4, 42747-42752)
└─123 paragraph[5] (843:1-849:4, 42754-43047)
├─0 text "::: {#ref-starcraft2 .csl-entry}\nVinyals, Oriol, Timo Ewalds, Sergey Bartunov, Petko Georgiev, Alexander\nSasha Vezhnevets, Michelle Yeo, Alireza Makhzani, et al. 2017.\n\"StarCraft II: A New Challenge for Reinforcement Learning.\" " (843:1-846:61, 42754-42982)
├─1 emphasis[1] (846:61-846:67, 42982-42988)
│ └─0 text "CoRR" (846:62-846:66, 42983-42987)
├─2 text "\nabs/1708.04782. " (846:67-847:17, 42988-43005)
├─3 link[1] (847:17-847:50, 43005-43038)
│ │ title: null
│ │ url: "http://arxiv.org/abs/1708.04782"
│ └─0 text "http://arxiv.org/abs/1708.04782" (847:18-847:49, 43006-43037)
└─4 text ".\n:::\n:::" (847:50-849:4, 43038-43047)
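The benchmark paragraph claims that PyTorch is within 17% of the fastest framework on every benchmark. The sketch below is a quick sanity check against the training-speed table. It transcribes only the mean throughputs and drops the ± spreads and the N/A cells; that is our simplification. It then computes how far PyTorch falls short of the best result for each model. The names `THROUGHPUT` and `shortfall_vs_fastest` are illustrative and are not part of any PyTorch tooling.

```python
# Mean training throughputs transcribed from the table (higher is better).
# Units differ per model (images/s, tokens/s, samples/s), so only ratios
# within one model are meaningful. "N/A" cells are simply omitted.
THROUGHPUT = {
    "AlexNet": {"Chainer": 778, "CNTK": 845, "MXNet": 1554,
                "PaddlePaddle": 933, "TensorFlow": 1422, "PyTorch": 1547},
    "VGG-19": {"CNTK": 84, "MXNet": 113, "PaddlePaddle": 112,
               "TensorFlow": 66, "PyTorch": 119},
    "ResNet-50": {"Chainer": 219, "CNTK": 210, "MXNet": 218,
                  "PaddlePaddle": 192, "TensorFlow": 200, "PyTorch": 212},
    "MobileNet": {"MXNet": 444, "PaddlePaddle": 557, "TensorFlow": 216, "PyTorch": 463},
    "GNMTv2": {"TensorFlow": 9631, "PyTorch": 15512},
    "NCF": {"TensorFlow": 4.8e6, "PyTorch": 5.4e6},
}


def shortfall_vs_fastest(table, framework="PyTorch"):
    """Map each model to 1 - framework/fastest (0.0 means `framework` is the fastest)."""
    gaps = {}
    for model, results in table.items():
        fastest = max(results.values())
        gaps[model] = 1.0 - results[framework] / fastest
    return gaps


gaps = shortfall_vs_fastest(THROUGHPUT)
for model, gap in gaps.items():
    print(f"{model:<10} {gap:6.1%}")
print(f"largest gap: {max(gaps.values()):.1%}")
```

Under these assumptions the largest gap is on MobileNet. There PyTorch reaches 463 images/s against PaddlePaddle's 557, a shortfall of about 16.9%, which is consistent with the 17% bound. The AlexNet gap (about 0.5%) is much smaller than the ±316 spread the table reports for PyTorch on that model.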
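The adoption metric works as follows. Each framework is counted at most once per paper, matching is case-insensitive, and PyTorch mentions are reported as a share of all framework mentions per month. The sketch below implements that metric under stated assumptions. The input format of `(month, full_text)` pairs, the whole-word regex matching and the helper names are all hypothetical. The paper does not describe how the arXiv text was obtained or matched.

```python
import re
from collections import Counter

# The eight tools listed in the Adoption paragraph.
FRAMEWORKS = ["Caffe", "Chainer", "CNTK", "Keras", "MXNet", "PyTorch", "TensorFlow", "Theano"]

# Assumed matching rule: whole-word, case-insensitive.
PATTERNS = {fw: re.compile(r"\b" + re.escape(fw) + r"\b", re.IGNORECASE) for fw in FRAMEWORKS}


def frameworks_in(text):
    """Return the set of frameworks mentioned in one paper (a set counts each at most once)."""
    return {fw for fw, pattern in PATTERNS.items() if pattern.search(text)}


def pytorch_share_by_month(papers):
    """papers: iterable of (month, full_text). Returns {month: PyTorch mentions / all mentions}."""
    per_month = {}
    for month, text in papers:
        per_month.setdefault(month, Counter()).update(frameworks_in(text))
    shares = {}
    for month in sorted(per_month):
        counts = per_month[month]
        total = sum(counts.values())
        if total:  # skip months in which no framework was mentioned
            shares[month] = counts["PyTorch"] / total
    return shares


papers = [
    ("2017-01", "We implement the model in PyTorch; pytorch makes this easy."),
    ("2017-01", "Baselines use TensorFlow and Keras."),
    ("2017-02", "All experiments run on Theano."),
]
print(pytorch_share_by_month(papers))
```

Collecting matches into a set per paper enforces the once-per-paper rule. With the toy input, January counts PyTorch once even though it appears twice, which gives a share of 1/3. Whole-word matching stops, for example, `Caffe2` from being counted as Caffe. Whether the original study made the same choice is not stated.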