Skip to content

Instantly share code, notes, and snippets.

@KexinFeng
Created January 26, 2023 03:11
Show Gist options
  • Save KexinFeng/2ad44daf14f77bc45c7b4c7fd8cf8f33 to your computer and use it in GitHub Desktop.
Save KexinFeng/2ad44daf14f77bc45c7b4c7fd8cf8f33 to your computer and use it in GitHub Desktop.
RCScope
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
/**
 * A per-thread scope that closes the {@link NDResource} objects registered with it when the scope
 * itself is closed.
 *
 * <p>Open scopes are kept on a thread-local stack ({@link #SCOPE_STACK}). {@link
 * #register(NDResource)} attaches a resource to the innermost open scope of the calling thread,
 * i.e. the most recently created scope that has not been closed yet.
 *
 * <p>This class has been derived from {@code org.bytedeco.javacpp.PointerScope} by Samuel Audet
 */
public class NDScope implements AutoCloseable {

    /** Per-thread stack of open scopes; the last element is the innermost scope. */
    private static final ThreadLocal<Deque<NDScope>> SCOPE_STACK =
            ThreadLocal.withInitial(ArrayDeque::new);

    /** Resources currently attached to this scope, in registration order. */
    private final List<NDResource> resources;

    /**
     * Creates a new scope and pushes it on the calling thread's {@link #SCOPE_STACK}, making it the
     * innermost scope for subsequent {@link #register(NDResource)} calls.
     */
    public NDScope() {
        resources = new ArrayList<>();
        SCOPE_STACK.get().addLast(this);
    }

    /**
     * Attaches a resource to the innermost open scope of the calling thread.
     *
     * <p>If the calling thread has no open scope, this method does nothing.
     *
     * @param resource the resource to close when the innermost scope is closed
     */
    public static void register(NDResource resource) {
        Deque<NDScope> queue = SCOPE_STACK.get();
        if (queue.isEmpty()) {
            return;
        }
        queue.getLast().resources.add(resource);
    }

    /**
     * Detaches a resource from this scope so that closing this scope will not close it.
     *
     * <p>Only this scope's own list is consulted. If the resource was attached to a different scope
     * (for example an inner scope that was innermost when the resource was registered), this method
     * has no effect on it.
     *
     * @param resource the resource to keep alive past this scope
     */
    public void retain(NDResource resource) {
        resources.remove(resource);
    }

    /**
     * Removes this scope from the calling thread's scope stack and closes every resource still
     * attached to it.
     *
     * <p>Every attached resource is closed even if some of them fail. The first {@link
     * RuntimeException} is rethrown after all resources have been visited, and any later ones are
     * attached to it as suppressed exceptions. Calling this method again closes nothing, because
     * the attached resources are detached on the first call.
     */
    @Override
    public void close() {
        // Detach before closing anything. This keeps the thread's stack consistent even if a
        // resource fails, and makes resources registered during the closes below attach to a
        // scope that is still open rather than to this finished one. Working on a snapshot
        // also avoids concurrent modification of the list.
        SCOPE_STACK.get().remove(this);
        List<NDResource> toClose = new ArrayList<>(resources);
        resources.clear();

        RuntimeException failure = null;
        for (NDResource resource : toClose) {
            try {
                resource.close();
            } catch (RuntimeException e) {
                if (failure == null) {
                    failure = e;
                } else if (failure != e) {
                    // addSuppressed rejects self-suppression, so guard against a shared instance.
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }
}
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.pytorch.refcount;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.pytorch.engine.PtNDArray;
import org.testng.Assert;
import org.testng.annotations.Test;
/** Checks how {@link NDScope} treats arrays created while it is open, with and without retain. */
public class RCScopeTest {

    @Test
    @SuppressWarnings("try")
    public void testRCScopeDetachingFromManager() {
        try (NDManager manager = NDManager.newBaseManager()) {
            // Case 1: an array created inside a scope is released once the scope closes.
            PtNDArray scopedArray;
            try (NDScope ignore = new NDScope()) {
                scopedArray = (PtNDArray) manager.create(new int[] {1});
            }
            Assert.assertTrue(scopedArray.isReleased());
            // After release, the array is tracked neither as a regular nor as a temp resource.
            Assert.assertFalse(scopedArray.getManager().hasResource(scopedArray));
            Assert.assertFalse(scopedArray.getManager().hasTempResource(scopedArray));

            // Case 2: retaining the array detaches it from the scope, so it survives the close.
            PtNDArray retainedArray;
            try (NDScope scope = new NDScope()) {
                retainedArray = (PtNDArray) manager.create(new int[] {1});
                scope.retain(retainedArray);
            }
            Assert.assertFalse(retainedArray.isReleased());
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment