__init__.py 1.0 KB

123456789101112131415161718192021222324252627282930313233
  1. from pycs_api.views.project import ProjectViewSet
  2. from pycs_api.views.model import ModelViewSet
  3. from pycs_api.views.label_provider import LabelProviderViewSet
  4. from rest_framework import routers
  5. from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
  6. from rest_framework_simplejwt.views import TokenObtainPairView as BaseTokenObtainView
  7. from rest_framework_simplejwt.views import TokenRefreshView as BaseTokenRefreshView
  8. class CustomTokenSerializer(TokenObtainPairSerializer):
  9. def validate(self, attrs):
  10. data = super().validate(attrs)
  11. # Add custom data
  12. data['username'] = self.user.username
  13. return data
  14. class TokenObtainPairView(BaseTokenObtainView):
  15. serializer_class = CustomTokenSerializer
  16. class TokenRefreshView(BaseTokenRefreshView):
  17. # serializer_class = CustomTokenSerializer
  18. pass
  19. router = routers.DefaultRouter()
  20. router.register(r'model', ModelViewSet)
  21. router.register(r'label-provider', LabelProviderViewSet)
  22. router.register(r'project', ProjectViewSet)