blob: afc6b7271cc654ef701867c1ab28e3fa469f6a0d [file] [log] [blame]
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
import numpy as np
import torch
import torch.onnx
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
class LogSoftMaxDecoder(nn.Module):
    """Pick next tokens from logits and return log-probabilities with them.

    A temperature of exactly ``0`` selects the arg-max along the last
    dimension; any other temperature draws a sample from a categorical
    distribution built from ``logits / temperature``.
    """

    def __init__(self, temperature):
        """Store the sampling temperature and build the log-softmax layer.

        Args:
            temperature: Divisor applied to the logits before sampling;
                ``0`` switches ``forward`` to arg-max selection.
        """
        super().__init__()
        self.logSoftMax = nn.LogSoftmax(dim=-1)
        self.temperature = temperature

    def forward(self, logits: Tensor):
        """Choose token ids and compute log-probabilities over the last axis.

        Args:
            logits: Unnormalised scores. The last dimension is the one
                reduced for token selection and normalised by log-softmax.

        Returns:
            A ``(next_tokens, logprobs)`` pair. ``logprobs`` is the
            log-softmax of the unscaled ``logits``; the temperature only
            influences which token is chosen.
        """
        if self.temperature != 0:
            # Stochastic path: sharpen/flatten the distribution, then sample.
            scaled = logits / self.temperature
            next_tokens = Categorical(logits=scaled).sample()
        else:
            # Greedy path: dividing by zero is avoided by taking the arg-max.
            next_tokens = logits.argmax(dim=-1)

        logprobs = self.logSoftMax(logits)
        return next_tokens, logprobs
# Build the dummy input and the ONNX tensor names before exporting.
# NOTE(review): a commented-out alternative dummy input,
# torch.randn(1, 51864), used to sit here — confirm whether that size is
# still relevant before dropping this note.
logits = torch.randn(1, 1, 51865)
input_names = ["logits"]
# NOTE: the second output is named "probs" although the decoder's forward()
# returns log-softmax values; the name is kept so existing consumers of the
# exported model keep working.
output_names = ["token_ids", "probs"]

model = LogSoftMaxDecoder(0.2)
model.eval()
print(model)

torch.onnx.export(
    model,
    (logits,),  # explicit 1-tuple; the original "(logits)" was a bare tensor
    "logSoftMaxdecoder.onnx",
    verbose=False,
    # Reuse the lists defined above so the names live in exactly one place.
    input_names=input_names,
    output_names=output_names,
    # Axis 2 (the last axis of the dummy input) is exported as dynamic.
    dynamic_axes={
        "logits": [2],
    },
)