
@eggsyntax
Created November 15, 2023 20:43

Notes

  • Point of confusion [resolved via email]: it's not clear to me that 'gradients of outputs wrt inputs' is actually new information on top of 'the model's outputs for all inputs'. Maybe I'm thinking too much in terms of LMs though?
  • If we think about it in terms of something more continuous like image classification, the gradients are valuable in that they provide information about what parts of the input are most important, in the sense of "what inputs would, if tweaked, have the largest impact on the output" (for a specific case).
  • In the limit, we can discover (and 'describe') everything about the model, eg by creating an enormous lookup table (by iterating over every possible input and recording the output produced, possibly with some additional complexity from tracking any additional internal state that the model has). This obviously isn't especially helpful for human-level understanding of a model's behavior, and would take an infeasible amount of time to create for any model large enough to be of interest. But it's interesting to notice that it's in principle not necessary to have any access to the model's internals.
    • This does assume that the model is not being modified at inference time in response to inputs.
  • It would be nice to be able to simplify that into a description of the boundaries of the input regions that would produce a particular output, but that might be very difficult since those boundaries aren't necessarily continuous. It could presumably be brute-forced in principle, though not in practice.
  • A VERY quick glance at the literature for existing model-agnostic methods finds:
    • Papers and posts I looked at:
    • Local methods
      • LIME - train a simpler, interpretable surrogate model on a dataset produced by perturbing the input in the neighborhood of a particular instance and recording the model's outputs
      • Anchor - similar to LIME but creates a list of rules with high likelihood of correctness locally
      • Prototype model - gives the user the predicted class score, the most similar prototype, and the six most important features
      • Decision Boundary - description a bit confusing, but: find nearby inputs that change the output
      • Individual conditional expectation curves...seem similar to decision boundaries?
      • Shapley values of each feature
    • Global methods
      • Partial dependence plots, accumulated local effects plots: the effect of a feature on the prediction, averaged across cases
      • Feature interaction - seems like it looks only at interactions of small numbers of features?
      • Functional decomposition - seems similar to feature interaction
      • Permutation feature importance -- seems v relevant to the underlying question: calculate the loss increase after permuting the feature
      • Global surrogate models (~= Cynthia Rudin's work IIUC)
      • Prototypes & criticisms (ie 'representative' vs 'non-representative'). Seems like it maybe assumes access to the training set, which I'm reluctant to assume given the problem statement.
        • though I see an interesting point mentioned: you could use a clustering method that returns actual data points as cluster centers (eg k-medoids) and treat those centers as prototypes. Prototypes != features obv, but maybe there's something there.
  • Intuitively it seems like there's a good approach here where you could discover what the model thinks interactively, eg you have an interactive loop such that you can start from a question ("when does the model think an image is of a cat?") and examine the behavior with respect to that specific question. Seems more tractable than trying to give a broad overall description of everything that the model has learned. Doesn't necessarily count as a real answer to the original question, though, "what can you find out about what this model has learned from its training data?"
  • Stepping back: is there an important distinction between "what this model has learned from its training data" vs "how this model's behavior can be explained"? I'm not sure.
    • The former suggests something more like "what are the features that the model has learned", vs eg "why did the model respond to input A with X rather than Y?"
    • The former seems less amenable to explanation-ish approaches like the ones above
  • A core question re 'what this model has learned from its training data' -- it seems like a critical part of that question is: what features has the model learned? It's not immediately obvious to me that we can decompose input/output into the features the model actually learns; maybe we can only pick a set of assumed features and figure out how important those are.
  • It would be really nice if you could do something like 'find all inputs such that they produce dangerous output x', but I don't see a non-exhaustive way to do that -- although maaaybe you could use the gradients (in the final part of the question) to try to find directions that push toward 'more dangerous'. Seems like that would only work if you had a clear way to figure out a direction in some meaningful vector space from arbitrary output x toward some more dangerous output y.
    • Which seems very questionable, both whether you can even come up with a meaningful vector space there in the absence of the model internals, and whether (even if you could) you could usefully define a direction from x to y. And of course how would you even pick what dangerous output to consider?
    • All of this seems MUCH more plausible for something like a numeric regression model with very low-dimensional output. If I think about it in the context of an LM it seems totally implausible.
    • Given more time I'd go do some more reading and thinking about directionality and the exact meaning of gradients in something like a sequence-to-sequence language model.
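The 'enormous lookup table' idea from the notes above becomes concrete for a toy model whose entire input space is tiny. This is a minimal sketch. It assumes a hypothetical stateless black box, `black_box`, over 3-bit inputs, which is an illustrative stand-in and not anything from the problem statement. The table has 2**n entries for n binary inputs, which is why the approach is infeasible for any model of interest.

```python
from itertools import product

def black_box(bits):
    # Hypothetical stateless model: we may call it but not look inside.
    return int(bits[0] and (bits[1] or bits[2]))

def build_lookup_table(model, n_inputs):
    """Query the model on every possible binary input and record each output."""
    return {bits: model(bits) for bits in product((0, 1), repeat=n_inputs)}

table = build_lookup_table(black_box, 3)
print(len(table))  # 8
```

For a language model, the same loop would have to run over every token sequence up to the context length.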
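The LIME entry can be sketched as follows. The method perturbs the input around one instance, weights each sample by its proximity with an exponential kernel, and fits a weighted linear model. The slopes of that model are the local explanation. This is a minimal sketch of the idea and not the `lime` package's API. The black box, the Gaussian perturbation scale, and the kernel width are illustrative assumptions. The real method also maps inputs to an interpretable representation and can select a sparse subset of features, and this sketch skips both steps.

```python
import math
import random

def black_box(x):
    # Hypothetical nonlinear model that we can query but not inspect.
    return math.sin(x[0]) + x[1] ** 2

def solve(a, b):
    """Solve a @ w = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    w = [0.0] * n
    for r in range(n - 1, -1, -1):
        w[r] = (m[r][n] - sum(m[r][c] * w[c] for c in range(r + 1, n))) / m[r][r]
    return w

def lime_explain(model, x0, n_samples=500, scale=0.1, kernel_width=0.25, seed=0):
    """Weighted least-squares fit to model outputs on perturbations of x0.
    Returns [intercept, slope_1, ..., slope_d]; the slopes are the local explanation."""
    rng = random.Random(seed)
    d = len(x0)
    xtx = [[0.0] * (d + 1) for _ in range(d + 1)]
    xty = [0.0] * (d + 1)
    for _ in range(n_samples):
        z = [xi + rng.gauss(0.0, scale) for xi in x0]  # perturb the input, not the model
        dist2 = sum((zi - xi) ** 2 for zi, xi in zip(z, x0))
        weight = math.exp(-dist2 / kernel_width ** 2)  # nearby samples count more
        row = [1.0] + z
        y = model(z)
        for i in range(d + 1):
            xty[i] += weight * row[i] * y
            for j in range(d + 1):
                xtx[i][j] += weight * row[i] * row[j]
    return solve(xtx, xty)

coef = lime_explain(black_box, [0.0, 1.0])
print(coef[1:])  # local slopes; the true gradient at (0, 1) is (cos 0, 2 * 1) = (1, 2)
```

The fitted slopes come out close to the true local derivatives. This is the sense in which the surrogate is "locally faithful".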
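The Decision Boundary entry ("find nearby inputs that change the output") can be sketched as a brute-force search along one feature at a time. This is a minimal sketch. It assumes a hypothetical classifier, `black_box`, and a fixed step size. Counterfactual-explanation methods usually search over all features jointly and add plausibility constraints, and this sketch skips both. It also illustrates the worry above about boundaries that aren't continuous: a fixed-step search can step right over a thin region that produces a different output.

```python
def black_box(x):
    # Hypothetical classifier: 1 = "cat", 0 = "not cat".
    return int(2 * x[0] + x[1] > 3)

def nearest_flip_per_feature(model, x, step=0.01, max_delta=5.0):
    """For each feature, the smallest single-feature change (up to max_delta) that
    changes the output. Features with no flip in that range are left out."""
    base = model(x)
    flips = {}
    for i in range(len(x)):
        for k in range(1, int(round(max_delta / step)) + 1):
            # Try a step of size k * step in both directions before growing the step.
            for delta in (k * step, -k * step):
                z = list(x)
                z[i] += delta
                if model(z) != base:
                    flips[i] = delta
                    break
            if i in flips:
                break
    return flips

flips = nearest_flip_per_feature(black_box, [1.0, 0.5])
print(flips)  # feature 0 flips at about +0.26, feature 1 at about +0.51
```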
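For what it's worth, ICE curves are the per-instance counterpart of partial dependence plots. Each ICE curve shows how one instance's prediction changes as a single feature is swept over a grid, with the other features held fixed. The PDP is the pointwise average of those curves. This is a minimal sketch with a hypothetical `black_box`. Its interaction term is chosen so that the two views disagree. Accumulated local effects are not shown.

```python
def black_box(x):
    # Hypothetical model with an interaction between features 0 and 1.
    return x[0] * x[1] + x[1]

def ice_curves(model, data, feature, grid):
    """One curve per instance: the prediction as `feature` sweeps `grid`, other features fixed."""
    curves = []
    for row in data:
        curve = []
        for v in grid:
            z = list(row)
            z[feature] = v
            curve.append(model(z))
        curves.append(curve)
    return curves

def partial_dependence(model, data, feature, grid):
    """The PDP: the pointwise average of the ICE curves."""
    curves = ice_curves(model, data, feature, grid)
    return [sum(c[j] for c in curves) / len(curves) for j in range(len(grid))]

data = [[0.0, -1.0], [0.0, 1.0]]
grid = [0.0, 1.0, 2.0]
print(ice_curves(black_box, data, 0, grid))          # [[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]]
print(partial_dependence(black_box, data, 0, grid))  # [0.0, 0.0, 0.0]
```

The flat PDP suggests that feature 0 doesn't matter. The ICE curves show that it matters a lot, in opposite directions for the two instances.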
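Exact Shapley values can be computed by brute force when there are only a handful of features, which makes the definition easy to see. This is a minimal sketch. It assumes that "removing" a feature means replacing it with a fixed baseline value. That is only one convention; others average over a background dataset. It also uses a hypothetical `black_box`. The cost grows exponentially with the number of features, which is why practical tools approximate.

```python
from itertools import combinations
from math import factorial

def black_box(x):
    # Hypothetical model: additive in x0 and x1, plus an x1 * x2 interaction.
    return 3 * x[0] + 2 * x[1] + x[1] * x[2]

def shapley_values(model, x, baseline):
    """Exact Shapley values by enumerating every coalition of the other features.
    A feature outside the coalition is set to its baseline value."""
    n = len(x)

    def value(coalition):
        z = [x[j] if j in coalition else baseline[j] for j in range(n)]
        return model(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for members in combinations(others, size):
                coalition = set(members)
                total += weight * (value(coalition | {i}) - value(coalition))
        phi.append(total)
    return phi

x, baseline = [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]
phi = shapley_values(black_box, x, baseline)
print([round(p, 6) for p in phi])  # [3.0, 2.5, 0.5]; sums to f(x) - f(baseline) = 6.0
```

The interaction's contribution is split evenly between x1 and x2, while x0's purely additive effect goes to x0 alone.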
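Permutation feature importance only needs the ability to query the model on modified inputs. This is a minimal sketch. It assumes a hypothetical regression `black_box` and uses mean squared error as the loss. The premise doesn't include a labeled training set, so the sketch uses the model's own outputs on the unpermuted data as the reference labels. That is an assumption rather than the standard setup, which uses true labels. It measures how much the output depends on each feature, rather than how much accuracy does.

```python
import random

def black_box(x):
    # Hypothetical regression model: depends on feature 0, ignores feature 1.
    return 4.0 * x[0]

def mse(model, X, y):
    return sum((model(row) - t) ** 2 for row, t in zip(X, y)) / len(y)

def permutation_importance(model, X, y, n_repeats=10, seed=0):
    """Mean increase in loss after shuffling one feature's column across rows."""
    rng = random.Random(seed)
    base = mse(model, X, y)
    importances = []
    for j in range(len(X[0])):
        increases = []
        for _ in range(n_repeats):
            column = [row[j] for row in X]
            rng.shuffle(column)
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            increases.append(mse(model, X_perm, y) - base)
        importances.append(sum(increases) / n_repeats)
    return importances

data_rng = random.Random(1)
X = [[data_rng.uniform(-1, 1), data_rng.uniform(-1, 1)] for _ in range(200)]
y = [black_box(row) for row in X]  # assumption: model's own outputs stand in for labels
imp = permutation_importance(black_box, X, y)
print(imp)  # feature 0 large, feature 1 exactly 0.0
```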
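A global surrogate fits an interpretable model to the black box's own outputs, not to ground-truth labels. It then reports its fidelity, meaning how often the surrogate agrees with the black box. This is a minimal sketch. It assumes a hypothetical two-feature `black_box` and uses a single-threshold decision stump as the deliberately simple surrogate.

```python
def black_box(x):
    # Hypothetical opaque classifier over two features.
    return int(x[0] > 0.3 and x[1] < 0.9)

def fit_stump(X, labels):
    """Pick the single-feature threshold rule that agrees most often with `labels`.
    sign=1 means 'predict 1 iff x[f] > t'; sign=-1 means 'predict 1 iff x[f] <= t'."""
    best = None
    for f in range(len(X[0])):
        for t in sorted({row[f] for row in X}):
            for sign in (1, -1):
                preds = [int((row[f] > t) == (sign == 1)) for row in X]
                agreement = sum(p == lab for p, lab in zip(preds, labels)) / len(labels)
                if best is None or agreement > best[0]:
                    best = (agreement, f, t, sign)
    return best

X = [[i / 10, j / 10] for i in range(10) for j in range(10)]
labels = [black_box(row) for row in X]  # the surrogate is trained on the black box's outputs
fidelity, feature, threshold, sign = fit_stump(X, labels)
print(feature, threshold, sign, fidelity)  # 0 0.3 1 0.94
```

The surrogate is easy to read, but it is wrong on 6% of this grid. This is a small-scale version of the concern in these notes that a surrogate can't be trusted to give the correct answer in all cases.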
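The clustering idea in the prototypes note can be sketched with k-medoids. Its cluster centers are always actual data points, so they can be read directly as prototypes. This is a minimal sketch of the simple alternating variant. It assumes Euclidean distance, naive initialization, and a toy 2-D dataset. It still needs a set of inputs to cluster, which runs into the same training-set-access concern.

```python
def dist(a, b):
    # Euclidean distance.
    return sum((p - q) ** 2 for p, q in zip(a, b)) ** 0.5

def k_medoids(points, k, n_iter=20):
    """Alternating k-medoids: each cluster center is always one of the input points."""
    medoids = points[:k]  # naive initialization, good enough for a sketch
    for _ in range(n_iter):
        clusters = [[] for _ in range(k)]
        for p in points:
            nearest = min(range(k), key=lambda i: dist(p, medoids[i]))
            clusters[nearest].append(p)
        new_medoids = []
        for i, members in enumerate(clusters):
            members = members or [medoids[i]]  # keep the old medoid if a cluster empties
            # The new medoid is the member with the smallest total distance to the others.
            new_medoids.append(min(members, key=lambda c: sum(dist(c, q) for q in members)))
        if new_medoids == medoids:
            break
        medoids = new_medoids
    return medoids

points = [[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]]
prototypes = k_medoids(points, 2)
print(prototypes)  # [[0, 0], [10, 10]]
```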
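The idea of using gradients to push toward 'more dangerous' outputs is easiest to picture in the low-dimensional, scalar-output case that the notes call more plausible. This is a minimal sketch of gradient ascent on the input. It assumes a hypothetical scalar score, `black_box_score`, as a stand-in for 'dangerousness'. The gradient is estimated with central finite differences, using black-box queries only, rather than taken from the gradients the question provides.

```python
def black_box_score(x):
    # Hypothetical scalar output to push upward; its maximum is at (2, -1).
    return -((x[0] - 2.0) ** 2) - ((x[1] + 1.0) ** 2)

def numerical_gradient(f, x, h=1e-5):
    """Central finite differences: needs only queries, no model internals."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad

def ascend(f, x, lr=0.1, steps=200):
    """Repeatedly move the input a small step in the direction that increases f."""
    x = list(x)
    for _ in range(steps):
        g = numerical_gradient(f, x)
        x = [xi + lr * gi for xi, gi in zip(x, g)]
    return x

x_end = ascend(black_box_score, [0.0, 0.0])
print(x_end)  # converges toward (2, -1), the maximum of this toy score
```

The hard part flagged in the notes is not this loop. It is defining a meaningful scalar like `black_box_score` for an LM's output at all.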

CENTRALISH POINTS

  • What does the question mean? 'What can you find out about what this model has learned from its training data?' What kinds of things do you WANT to learn?
    • You've fully learned, in principle, everything about what the model does, in an extensional sense. But that's not very useful in practice.
    • We'd like to be able to understand its behavior wrt specific input/output (beyond just 'it produces output Y when given input X'). It looks like there are pretty good techniques for doing that (LIME, anchors, prototypes, etc).
    • We'd like to be able to understand its global behavior in a way that's human-meaningful. There are also known techniques for doing this. I think that there's a balance here between accuracy of explanation and human-meaningfulness; any explanation that captures all the behavior of a large model is (almost certainly) going to itself be too complicated to be human-meaningful (although some eg Rudin seem to disagree). That's not specific to black-box / model-agnostic approaches though.
    • We'd like to learn what features it's learned, and I'm not sure we can truly do that (but also not sure that we can't -- something something clustering (and maybe something something PCA for the most important features)). We can however create a set of features that we find useful, and talk about what the model does wrt those.
    • We'd obviously like to be able to answer questions about whether it'll behave in a safe/aligned way. In a certain trivial sense we've got that given the premises (just feed it all the inputs you're worried about and see what the output is!), but of course running the model isn't necessarily a safe way to test what the model does, eg in the case of ASI (and again you can't do it exhaustively in practice). It might be safer to use something like a simpler global surrogate model to try to answer those questions, but we can't be confident in all cases that the surrogate model will give us the correct answer (if we could, we could just replace the original model with the surrogate; and a surrogate that faithful might itself be unsafe).
    • There's maybe an important conceptual distinction to be drawn between what the model has learned vs how the model behaves. The former is, I think, extremely difficult; for example any misaligned and sufficiently smart model could perfectly simulate the behavior of an aligned model at all times until it believes it's no longer being tested.
      • In some sense, though, if we're not testing the behavior of the model under conditions that convince it it's no longer being tested, we haven't actually tested all input of interest.
  • That last point suggests that I may be putting way too much weight in this analysis on the in-principle exhaustive enumeration of input/output pairs. Because 'we haven't actually tested all input of interest' is always going to be true in the real world. I may need to draw a sharper distinction between "we have tested all cases" and "we've tested every case that seemed important to us" (although time may disallow that because it would involve a lot of rethinking / rewriting).
  • Gradients definitely give us important extra knowledge. First, they tell us which parts of the input are most important in deciding the output (with caveats about non-linear interactions between inputs, I think). And their directionality gives us some knowledge about how to change the inputs in order to have a particular effect on the output -- although that gets complicated for models with high-dimensional outputs like seq2seq language models, in ways that I don't currently have a great understanding of.
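The PCA aside above is only a hunch. One possible reading, offered purely as an assumption, is to collect vectors that describe the model's behavior at many inputs, such as the gradients from the question. You would then find the direction along which those vectors vary most. This is a minimal sketch of that step using power iteration on the covariance matrix, with hypothetical toy data. Power iteration converges to the leading eigenvector when there is an eigenvalue gap and the random start isn't orthogonal to it. Whether such directions correspond to features the model has actually learned is exactly the open question.

```python
import random

def top_principal_component(vectors, n_iter=100, seed=0):
    """Leading principal direction of `vectors`, via power iteration on their covariance."""
    n, d = len(vectors), len(vectors[0])
    mean = [sum(v[j] for v in vectors) / n for j in range(d)]
    centered = [[v[j] - mean[j] for j in range(d)] for v in vectors]
    cov = [[sum(c[i] * c[j] for c in centered) / n for j in range(d)] for i in range(d)]
    rng = random.Random(seed)
    w = [rng.uniform(-1.0, 1.0) for _ in range(d)]
    for _ in range(n_iter):
        w = [sum(cov[i][j] * w[j] for j in range(d)) for i in range(d)]
        norm = sum(x * x for x in w) ** 0.5
        w = [x / norm for x in w]
    return w

# Toy stand-in for, say, gradient vectors collected at many inputs (hypothetical data).
vectors = [[float(t), 2.0 * t, 0.0] for t in range(-5, 6)]
w = top_principal_component(vectors)
print(w)  # a unit vector along +/-(1, 2, 0) / sqrt(5)
```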
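The point about gradients showing which parts of the input matter most, with a caveat about interactions, can be seen in a few lines. This is a minimal sketch. It assumes a hypothetical oracle, `model_with_grad`, that returns both the output and its gradient with respect to each input, standing in for the gradient access the question grants. Ranking inputs by absolute gradient gives a local importance ordering. The sign of each gradient gives the direction to move that input to raise the output.

```python
def model_with_grad(x):
    """Hypothetical oracle returning the output and d(output)/d(input_i) for each i."""
    y = 3.0 * x[0] + x[1] * x[2]
    grad = [3.0, x[2], x[1]]
    return y, grad

def rank_by_saliency(x):
    """Input indices ordered by |gradient|, largest first (a local importance ranking)."""
    _, grad = model_with_grad(x)
    return sorted(range(len(x)), key=lambda i: abs(grad[i]), reverse=True)

print(rank_by_saliency([1.0, 5.0, 0.0]))  # [2, 0, 1]: x[1] looks irrelevant here...
print(rank_by_saliency([1.0, 0.0, 5.0]))  # [1, 0, 2]: ...but is the most important here
```

Because of the x[1] * x[2] interaction, the same input can look unimportant at one point and dominant at another. The ranking is strictly local.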